diff --git a/jlx/src/main.cpp b/jlx/src/main.cpp index 1a744a3..66a5812 100644 --- a/jlx/src/main.cpp +++ b/jlx/src/main.cpp @@ -63,5 +63,15 @@ int main(int argc, char** argv) { std::cout << "Parsed " << rt->statements.size() << " statements" << std::endl; + auto type_checker = jlx::type_checker(rt->statements.begin(), rt->statements.end()); + try { + type_checker.check(); + } catch (jlx::type_error& err) { + std::cerr << "Type error: " << err.what() << std::endl; + return 1; + } + + std::cout << "Type checking OK." << std::endl; + return 0; } diff --git a/libjlx/modules/ast.cppm b/libjlx/modules/ast.cppm index dbdae23..83b579c 100644 --- a/libjlx/modules/ast.cppm +++ b/libjlx/modules/ast.cppm @@ -17,8 +17,8 @@ namespace jlx { Block, FunctionDeclaration, VariableDeclaration, - LiteralValue, IfStatement, + ReturnStatement, }; export enum expression_type { @@ -52,7 +52,7 @@ namespace jlx { } - explicit root_statement(std::vector> statements, const token& t) : statement(Root, t), statements(std::move(statements)) { + root_statement(std::vector> statements, const token& t) : statement(Root, t), statements(std::move(statements)) { } @@ -77,11 +77,25 @@ namespace jlx { expression_type et = EtInvalid; std::optional evaluated_type; - virtual std::unique_ptr clone() const = 0; + [[nodiscard]] virtual std::unique_ptr clone() const = 0; ~expression() override = default; }; + export struct return_statement : public statement { + explicit return_statement(const token& t) : statement(ReturnStatement, t) { + + } + + return_statement(std::unique_ptr expression, const token& t) : statement(ReturnStatement, t), expression(std::move(expression)) { + + } + + std::unique_ptr expression; + + ~return_statement() override = default; + }; + export struct block : public statement { explicit block(const token& t) : statement(Block, t) { @@ -455,6 +469,13 @@ namespace jlx { } } fail_invalid_token(*current); + break; + case Keyword: + if (previous != nullptr) { + return previous; + } + fail_invalid_token(*current); + break; default: fail_invalid_token(*current); } @@ -531,6 +552,23 @@ namespace jlx { return std::make_unique(start, std::move(function_name), std::move(params), std::move(return_type), std::move(block)); } + std::unique_ptr parse_return_statement() { + auto start = *current; + + if (current->type != Keyword || current->content != "return") { + fail_invalid_token(*current); + } + + next(false); + + std::unique_ptr expr = nullptr; + if (current != last) { + expr = parse_expression(); + } + + return std::make_unique(std::move(expr), start); + } + std::string parse_type(){ if (current->type != Identifier) { fail_invalid_token(*current); @@ -551,6 +589,8 @@ namespace jlx { return parse_if_statement(); } else if (current->content == "fun" && top_level) { return parse_function(); + } else if (current->content == "return") { + return parse_return_statement(); } } diff --git a/libjlx/modules/tokenizer.cppm b/libjlx/modules/tokenizer.cppm index 117c226..93fff78 100644 --- a/libjlx/modules/tokenizer.cppm +++ b/libjlx/modules/tokenizer.cppm @@ -80,13 +80,14 @@ namespace jlx { export class tokenizer { source_stream source; - static constexpr std::array keywords = {{ + static constexpr std::array keywords = {{ "if", "else", "fun", //"struct", "let", - "var" + "var", + "return" }}; static constexpr std::array punctuations = {{ diff --git a/libjlx/modules/type_checker.cppm b/libjlx/modules/type_checker.cppm index 0a0036b..09bd756 100644 --- a/libjlx/modules/type_checker.cppm +++ b/libjlx/modules/type_checker.cppm @@ -1,5 +1,6 @@ module; +#include #include #include #include @@ -19,18 +20,20 @@ namespace jlx { export class type_checker_runtime_context { public: - virtual std::string get_identifier_type(const std::string_view&) const = 0; - virtual runtime_function get_function(const std::string_view&) const = 0; + virtual ~type_checker_runtime_context() = default; + + virtual std::string_view get_identifier_type(const std::string_view&) const = 0; + virtual const runtime_function& get_function(const std::string_view&) const = 0; }; - export class type_error : std::runtime_error { + export class type_error : public std::runtime_error { public: explicit type_error(const token& t, const std::string_view& provided_type) : std::runtime_error(std::format("Type {} is invalid at {}:{}:{}", provided_type, t.source_file, t.line, t.col)) { } }; - export class type_checker { + class type_checker_utils { static std::size_t size_of(const std::string_view& type) { if (type == "u8" || type == "i8" || type == "char") { return 1; @@ -57,7 +60,7 @@ namespace jlx { return type == "f32" || type == "f64" || type == "comptime_float"; } - + public: static bool is_convertible_to(const std::string_view& source_type, const std::string_view& target_type) { if (target_type == source_type) { return true; @@ -145,10 +148,10 @@ namespace jlx { return false; } - static void evaluate_expression_type(std::unique_ptr& expr, const type_checker_runtime_context& ctx) { - const auto& content = expr->t.content; - if (expr->et == EtLiteralValue) { - auto* literal_expression = dynamic_cast(expr.get()); + static void evaluate_expression_type(expression& expr, const type_checker_runtime_context& ctx) { + const auto& content = expr.t.content; + if (expr.et == EtLiteralValue) { + auto* literal_expression = dynamic_cast(&expr); if (literal_expression == nullptr) { throw std::runtime_error("Internal type checker/AST error"); } @@ -157,22 +160,25 @@ namespace jlx { switch(literal_expression->t.type) { case token_type::Number: if (std::find(content.begin(), content.end(), '.') != content.end()) { - type = "comptime.float"; + type = "comptime_float"; } else { - type = "comptime.int"; + type = "comptime_int"; } break; case token_type::String: - type = "comptime.string"; + type = "string"; + break; + case token_type::Boolean: + type = "boolean"; break; default: throw std::runtime_error("Unsupported token for literal value expression"); } - expr->evaluated_type = type; + expr.evaluated_type = type; return; - } else if (expr->et == EtSingleValueOperation) { - auto* single_op = dynamic_cast(expr.get()); + } else if (expr.et == EtSingleValueOperation) { + auto* single_op = dynamic_cast(&expr); if (single_op == nullptr) { throw std::runtime_error("Internal type checker/AST error"); @@ -184,20 +190,20 @@ namespace jlx { auto& opr = single_op->operand; if (opr->evaluated_type == std::nullopt) { - evaluate_expression_type(opr, ctx); + evaluate_expression_type(*opr, ctx); } auto& op = single_op->operator_token.content; auto base_type = opr->evaluated_type.has_value() ? opr->evaluated_type.value() : "void"; if (op == "+") { if (type_is_integer(base_type) || type_is_float(base_type)) { - expr->evaluated_type = std::move(base_type); + expr.evaluated_type = std::move(base_type); return; } throw type_error(opr->t, base_type); } else if(op == "-") { if (type_is_integer(base_type) || type_is_float(base_type)) { - expr->evaluated_type = std::move(base_type); + expr.evaluated_type = std::move(base_type); //TODO: Maybe run immediately for 'comptime_int' and 'comptime_float'? return; } @@ -207,13 +213,13 @@ namespace jlx { throw type_error(opr->t, base_type); } - expr->evaluated_type = "boolean"; + expr.evaluated_type = "boolean"; return; } else { throw std::runtime_error("Cannot type-check expression: invalid operator for single-value operation"); } - } else if (expr->et == EtDualValueOperation) { - auto* dual_op = dynamic_cast(expr.get()); + } else if (expr.et == EtDualValueOperation) { + auto* dual_op = dynamic_cast(&expr); if (dual_op == nullptr) { throw std::runtime_error("Internal type checker/AST error"); @@ -226,11 +232,11 @@ namespace jlx { auto& expr0 = dual_op->first_operand; auto& expr1 = dual_op->second_operand; if (!expr0->evaluated_type.has_value()) { - evaluate_expression_type(expr0, ctx); + evaluate_expression_type(*expr0, ctx); } if (!expr1->evaluated_type.has_value()) { - evaluate_expression_type(expr1, ctx); + evaluate_expression_type(*expr1, ctx); } auto& op = dual_op->operator_token.content; @@ -243,27 +249,27 @@ namespace jlx { } if (type0 == "string") { - expr->evaluated_type = "string"; + expr.evaluated_type = "string"; return; } if (type0 == type1) { - expr->evaluated_type = std::move(type0); + expr.evaluated_type = std::move(type0); return; } if ((type_is_integer(type0) && type_is_integer(type1)) || (type_is_float(type0) && type_is_float(type1))) { - expr->evaluated_type = size_of(type0) > size_of(type1) ? type0 : type1; + expr.evaluated_type = size_of(type0) > size_of(type1) ? type0 : type1; return; } if (type_is_float(type0) && type1 == "comptime_int") { - expr->evaluated_type = type0; + expr.evaluated_type = type0; return; } if (type_is_float(type1) && type0 == "comptime_int") { - expr->evaluated_type = type1; + expr.evaluated_type = type1; return; } @@ -274,25 +280,25 @@ namespace jlx { } if ((type_is_integer(type0) && type_is_integer(type1)) || (type_is_float(type0) && type_is_float(type1))) { - expr->evaluated_type = size_of(type0) > size_of(type1) ? type0 : type1; + expr.evaluated_type = size_of(type0) > size_of(type1) ? type0 : type1; return; } if (type_is_float(type0) && type1 == "comptime_int") { - expr->evaluated_type = type0; + expr.evaluated_type = type0; return; } if (type_is_float(type1) && type0 == "comptime_int") { - expr->evaluated_type = type1; + expr.evaluated_type = type1; return; } throw std::runtime_error("Unsupported type combination for operator '+'"); } else if (op == "==" || op == "!=") { if (is_convertible_to(type1, type0) || is_convertible_to(type0, type1)) { - expr->evaluated_type = "boolean"; + expr.evaluated_type = "boolean"; return; } @@ -310,11 +316,11 @@ namespace jlx { throw type_error(expr1->t, type1); } - expr->evaluated_type = "boolean"; + expr.evaluated_type = "boolean"; return; } - } else if (expr->et == EtFunctionCall) { - auto* call = dynamic_cast(expr.get()); + } else if (expr.et == EtFunctionCall) { + auto* call = dynamic_cast(&expr); if (call == nullptr) { throw std::runtime_error("Internal type checker/AST error"); @@ -332,7 +338,7 @@ namespace jlx { auto& arg = call->arguments[i]; if (!arg->evaluated_type.has_value()) { - evaluate_expression_type(arg, ctx); + evaluate_expression_type(*arg, ctx); } auto src = arg->evaluated_type.has_value() ? arg->evaluated_type.value() : "void"; @@ -341,20 +347,222 @@ namespace jlx { } } - expr->evaluated_type = fn.return_type; + expr.evaluated_type = fn.return_type; return; - } else if (expr->et == EtIdentifier) { - auto* ident_expr = dynamic_cast(expr.get()); + } else if (expr.et == EtIdentifier) { + auto* ident_expr = dynamic_cast(&expr); if (ident_expr == nullptr) { throw std::runtime_error("Internal type checker/AST error"); } - expr->evaluated_type = ctx.get_identifier_type(ident_expr->name); + expr.evaluated_type = ctx.get_identifier_type(ident_expr->name); return; } throw std::runtime_error("Cannot type-check expression: Unknown expression"); } }; + export template + concept statement_iterator = requires() { + requires std::same_as>; + std::bidirectional_iterator; + }; + + struct type_checker_scope { + std::unordered_map variables; + std::unordered_map functions; + std::optional intended_return_type; + }; + + export template E> + class type_checker : public type_checker_runtime_context { + private: + T current; + E end; + std::vector scopes; + + void check_statement(statement& s, bool skip_block_scope = false) { + switch (s.type) { + case Root: { + auto* root = dynamic_cast(&s); + + for (auto& st : root->statements) { + check_statement(*st); + } + } + break; + case Block: { + auto* blk = dynamic_cast(&s); + + if (!skip_block_scope) { + scopes.emplace_back(type_checker_scope{}); + } + + for (auto& st : blk->statements) { + check_statement(*st); + } + + if (!skip_block_scope) { + scopes.pop_back(); + } + } + break; + case IfStatement: { + auto* if_stmt = dynamic_cast(&s); + + auto& cond = if_stmt->condition; + if (!cond->evaluated_type.has_value()) { + type_checker_utils::evaluate_expression_type(*if_stmt->condition, *this); + } + + auto tp = cond->evaluated_type.has_value() ? cond->evaluated_type.value() : "void"; + + if (tp != "boolean") { + throw type_error(cond->t, tp); + } + + check_statement(*if_stmt->block); + } + break; + case VariableDeclaration: { + auto* decl = dynamic_cast(&s); + + if (has_variable(decl->name)) { + throw std::runtime_error("Duplicate variable name detected"); + } + + std::optional init_type = std::nullopt; + if (decl->initial_expression != nullptr) { + if (!decl->initial_expression->evaluated_type.has_value()) { + type_checker_utils::evaluate_expression_type(*decl->initial_expression, *this); + } + init_type = decl->initial_expression->evaluated_type; + } + + auto type = decl->type; + + if (init_type.has_value()) { + if (!type.has_value()) { + type = init_type; + } else if (!type_checker_utils::is_convertible_to(init_type.value(), type.value())) { + throw type_error(decl->initial_expression->t, init_type.value()); + } + } + + if (!type.has_value()) { + throw std::runtime_error("Cannot infer variable declaration type"); + } + + scopes[scopes.size() - 1].variables.insert_or_assign(decl->name, type.value()); + } + break; + case FunctionDeclaration: { + auto* decl = dynamic_cast(&s); + + if (has_function(decl->name)) { + throw std::runtime_error("Duplicate function name detected"); + } + + if (!decl->return_type.has_value()) { + decl->return_type = "void"; + } + + runtime_function f; + f.return_type = decl->return_type.value(); + + for (auto& a : decl->parameters) { + f.arguments.emplace_back(a.type); + } + + scopes[scopes.size() - 1].functions.insert_or_assign(decl->name, f); + + auto& function_scope = scopes.emplace_back(type_checker_scope{}); + function_scope.intended_return_type = decl->return_type; + + check_statement(*decl->body, true); + + scopes.pop_back(); + } + break; + case Expression: { + auto* expr = dynamic_cast(&s); + + type_checker_utils::evaluate_expression_type(*expr, *this); + } + break; + case ReturnStatement: { + auto* rtr = dynamic_cast(&s); + + auto& scope = scopes.back(); + + if (rtr->expression != nullptr && !rtr->expression->evaluated_type.has_value()) { + type_checker_utils::evaluate_expression_type(*rtr->expression, *this); + } + + auto rvl = scope.intended_return_type.has_value() ? scope.intended_return_type.value() : "void"; + auto exvl = rtr->expression != nullptr && rtr->expression->evaluated_type.has_value() ? rtr->expression->evaluated_type.value() : "void"; + + if (!type_checker_utils::is_convertible_to(exvl, rvl)) { + throw type_error(rtr->expression->t, exvl); + } + } + break; + } + } + + [[nodiscard]] bool has_variable(const std::string_view& name) const { + auto& scope = scopes[scopes.size() - 1]; + auto itx = scope.variables.find(std::string(name)); + if (itx != scope.variables.end()) { + return true; + } + + return false; + } + + [[nodiscard]] bool has_function(const std::string_view& name) const { + for (const auto& scope : std::ranges::reverse_view(scopes)) { + auto itx = scope.functions.find(std::string(name)); + if (itx != scope.functions.end()) { + return true; + } + } + + return false; + } + public: + type_checker(T start, E end) : current(start), end(end) { + scopes.emplace_back(); + } + + void check() { + while (current != end) { + check_statement(**current); + + ++current; + } + } + + [[nodiscard]] std::string_view get_identifier_type(const std::string_view &name) const override { + for (const auto & scope : std::ranges::reverse_view(scopes)) { + auto itx = scope.variables.find(std::string(name)); + if (itx != scope.variables.end()) { + return itx->second; + } + } + + throw std::runtime_error("Missing identifier"); + } + + [[nodiscard]] const runtime_function& get_function(const std::string_view &name) const override { + for (const auto& scope : std::ranges::reverse_view(scopes)) { + if (auto itx = scope.functions.find(std::string(name)); itx != scope.functions.end()) { + return itx->second; + } + } + + throw std::runtime_error("Missing function"); + } + }; } \ No newline at end of file