From 50205e8185deed8476d101337bcfc7e156319ee9 Mon Sep 17 00:00:00 2001 From: John Stefanelli Date: Thu, 29 May 2025 15:39:31 +0200 Subject: [PATCH] [libjlx] Add type_checker module --- libjlx/CMakeLists.txt | 1 + libjlx/modules/ast.cppm | 185 ++++++++++++++-- libjlx/modules/main.cppm | 1 + libjlx/modules/tokenizer.cppm | 20 +- libjlx/modules/type_checker.cppm | 360 +++++++++++++++++++++++++++++++ 5 files changed, 543 insertions(+), 24 deletions(-) create mode 100644 libjlx/modules/type_checker.cppm diff --git a/libjlx/CMakeLists.txt b/libjlx/CMakeLists.txt index 8e550a1..da8ce16 100644 --- a/libjlx/CMakeLists.txt +++ b/libjlx/CMakeLists.txt @@ -6,6 +6,7 @@ target_sources(libjlx PUBLIC FILE_SET libjlx_modules TYPE CXX_MODULES FILES "${CMAKE_CURRENT_SOURCE_DIR}/modules/tokenizer.cppm" "${CMAKE_CURRENT_SOURCE_DIR}/modules/ast.cppm" "${CMAKE_CURRENT_SOURCE_DIR}/modules/utils.cppm" + "${CMAKE_CURRENT_SOURCE_DIR}/modules/type_checker.cppm" ) target_compile_options(libjlx PRIVATE $,/W4 /WX,-Wall -Wextra -Werror>) target_include_directories(libjlx PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") diff --git a/libjlx/modules/ast.cppm b/libjlx/modules/ast.cppm index 585d2b5..178bcba 100644 --- a/libjlx/modules/ast.cppm +++ b/libjlx/modules/ast.cppm @@ -2,9 +2,9 @@ module; #include #include -#include #include #include +#include export module jlx:ast; @@ -21,11 +21,13 @@ namespace jlx { IfStatement, }; - export enum class literal_value_type { - Boolean, - Numeric, - String, - Character, + export enum expression_type { + EtInvalid = 0, + EtLiteralValue, + EtSingleValueOperation, + EtDualValueOperation, + EtFunctionCall, + EtIdentifier }; export template @@ -36,8 +38,9 @@ namespace jlx { struct statement { ast_type type; + token t; - statement(ast_type type) : type(type) { + statement(ast_type type, const token& t) : type(type), t(t) { } @@ -45,11 +48,11 @@ namespace jlx { }; struct root_statement : public statement { - root_statement() : statement(Root) { + explicit root_statement(const token& t) : statement(Root, t) { } - root_statement(std::vector> statements) : statement(Root), statements(std::move(statements)) { + explicit root_statement(std::vector> statements, const token& t) : statement(Root, t), statements(std::move(statements)) { } @@ -59,19 +62,32 @@ namespace jlx { }; struct expression : public statement { - expression() : statement(Expression) { + explicit expression(const token& t) : statement(Expression, t) { } + expression(const expression&) = default; + expression& operator=(const expression& other) { + et = EtInvalid; + t = other.t; + evaluated_type = other.evaluated_type; + return *this; + } + + expression_type et = EtInvalid; + std::optional evaluated_type; + + virtual std::unique_ptr clone() const = 0; + ~expression() override = default; }; struct block : public statement { - block() : statement(Block) { + explicit block(const token& t) : statement(Block, t) { } - block(std::vector> statements) : statement(Block), statements(std::move(statements)) { + explicit block(const token& t, std::vector> statements) : statement(Block, t), statements(std::move(statements)) { } @@ -86,12 +102,12 @@ namespace jlx { }; struct function_declaration : public statement { - function_declaration() : statement(FunctionDeclaration) { + explicit function_declaration(const token& t) : statement(FunctionDeclaration, t) { } - function_declaration(std::string name, std::vector parameters, std::optional return_type, std::unique_ptr body) : - statement(FunctionDeclaration), name(std::move(name)), parameters(std::move(parameters)), return_type(std::move(return_type)), body(std::move(body)) { + function_declaration(const token& t, std::string name, std::vector parameters, std::optional return_type, std::unique_ptr body) : + statement(FunctionDeclaration, t), name(std::move(name)), parameters(std::move(parameters)), return_type(std::move(return_type)), body(std::move(body)) { } @@ -105,11 +121,11 @@ namespace jlx { }; struct variable_declaration : public statement { - variable_declaration() : statement(VariableDeclaration) { + variable_declaration(const token& t) : statement(VariableDeclaration, t) { } - bool constant; + bool constant = true; std::string name; std::optional type; std::unique_ptr initial_expression; @@ -118,7 +134,7 @@ namespace jlx { }; struct if_statement : public statement { - if_statement() : statement(IfStatement) { + if_statement(const token& t) : statement(IfStatement, t) { } @@ -126,6 +142,121 @@ namespace jlx { std::unique_ptr block; }; + struct literal_value : public expression { + explicit literal_value(const token& t): expression(t) { + et = EtLiteralValue; + } + + literal_value(const literal_value&) = default; + literal_value& operator=(const literal_value&) = default; + + [[nodiscard]] std::unique_ptr clone() const override { + return std::make_unique(*this); + } + }; + + struct single_operation : public expression { + single_operation(const token& t, std::unique_ptr operand, const token& operator_token) : expression(t), + operand(std::move(operand)), operator_token(operator_token) { + et = EtSingleValueOperation; + } + + single_operation(const single_operation& other) : expression(other.t), operand(other.operand->clone()), operator_token(other.operator_token) { + et = EtSingleValueOperation; + operand = other.operand->clone(); + } + + [[nodiscard]] std::unique_ptr clone() const override { + return std::make_unique(*this); + } + + single_operation& operator=(const single_operation& other){ + et = EtSingleValueOperation; + operand = other.operand->clone(); + operator_token = other.operator_token; + return *this; + } + + std::unique_ptr operand; + token operator_token; + }; + + struct dual_operation : public expression { + dual_operation(const token& t, std::unique_ptr first_operand, std::unique_ptr second_operand, const token& operator_token) : + expression(t), first_operand(std::move(first_operand)), second_operand(std::move(second_operand)), operator_token(operator_token) { + et = EtDualValueOperation; + } + + dual_operation(const dual_operation& other) : + expression(other.t), first_operand(other.first_operand->clone()), second_operand(other.second_operand->clone()), operator_token(other.operator_token) { + et = EtDualValueOperation; + } + + dual_operation& operator=(const dual_operation* other) { + et = EtDualValueOperation; + first_operand = other->first_operand->clone(); + second_operand = other->second_operand->clone(); + operator_token = other->operator_token;\ + return *this; + } + + [[nodiscard]] std::unique_ptr clone() const override { + return std::make_unique(*this); + } + + std::unique_ptr first_operand; + std::unique_ptr second_operand; + token operator_token; + }; + + struct function_call : public expression { + function_call(const token& t, std::string function_name, std::vector> arguments) : + expression(t), function_name(std::move(function_name)), arguments(std::move(arguments)) { + et = EtFunctionCall; + } + + function_call(const function_call& other) : expression(other.t) { + et = EtFunctionCall; + function_name = other.function_name; + arguments.reserve(other.arguments.size()); + for(auto& arg : other.arguments) { + arguments.emplace_back(arg->clone()); + } + } + + function_call& operator=(const function_call& other) { + et = EtFunctionCall; + function_name = other.function_name; + arguments.reserve(other.arguments.size()); + for(auto& arg : other.arguments) { + arguments.emplace_back(arg->clone()); + } + return *this; + } + + [[nodiscard]] std::unique_ptr clone() const override { + return std::make_unique(*this); + } + + std::string function_name; + std::vector> arguments; + }; + + struct identifier_expression : public expression { + identifier_expression(const token& t, std::string name) : expression(t), name(std::move(name)) { + et = EtIdentifier; + } + + identifier_expression(const identifier_expression&) = default; + identifier_expression& operator=(const identifier_expression&) = default; + + [[nodiscard]] std::unique_ptr clone() const override { + return std::make_unique(*this); + } + + std::string name; + }; + export template E> class parser { T current; @@ -165,6 +296,8 @@ namespace jlx { fail_invalid_token(*current); } + auto start = *current; + std::string name; std::optional type = std::nullopt; std::unique_ptr starting_value = nullptr; @@ -207,7 +340,7 @@ namespace jlx { starting_value = parse_expression(); } - auto var = std::make_unique(); + auto var = std::make_unique(start); var->constant = constant; var->name = std::move(name); var->type = std::move(type); @@ -221,6 +354,8 @@ namespace jlx { fail_invalid_token(*current); } + auto start = *current; + next(); if (current->type != Punctuation || current->content != "(") { @@ -239,7 +374,7 @@ namespace jlx { auto block = parse_block(); - auto statement = std::make_unique(); + auto statement = std::make_unique(start); statement->block = std::move(block); statement->condition = std::move(expr); @@ -251,6 +386,8 @@ namespace jlx { } std::unique_ptr parse_function() { + auto start = *current; + if (current->type != Keyword || current->content != "fun") { fail_invalid_token(*current); } @@ -317,7 +454,7 @@ namespace jlx { auto block = parse_block(); - return std::make_unique(std::move(function_name), std::move(params), std::move(return_type), std::move(block)); + return std::make_unique(start, std::move(function_name), std::move(params), std::move(return_type), std::move(block)); } std::string parse_type(){ @@ -352,6 +489,10 @@ namespace jlx { } std::unique_ptr parse() { + if (current == last) { + return nullptr; + } + auto start = *current; std::vector> top_level_statements; while(current != last) { @@ -362,7 +503,7 @@ namespace jlx { top_level_statements.push_back(std::move(s)); } - return std::make_unique(std::move(top_level_statements)); + return std::make_unique(std::move(top_level_statements), start); } }; } diff --git a/libjlx/modules/main.cppm b/libjlx/modules/main.cppm index 5566086..4bd9954 100644 --- a/libjlx/modules/main.cppm +++ b/libjlx/modules/main.cppm @@ -4,3 +4,4 @@ export module jlx; export import :source_stream; export import :tokenizer; export import :ast; +export import :type_checker; diff --git a/libjlx/modules/tokenizer.cppm b/libjlx/modules/tokenizer.cppm index f245df6..141b9ee 100644 --- a/libjlx/modules/tokenizer.cppm +++ b/libjlx/modules/tokenizer.cppm @@ -43,9 +43,18 @@ namespace jlx { export struct token { token_type type; + std::string source_file; std::string content; std::size_t line; std::size_t col; + + token(token_type type, const std::string& source_file, const std::string& content, std::size_t line, std::size_t col) : + type(type), source_file(source_file), content(content), line(line), col(col) { + + } + + token(const token&) = default; + token& operator=(const token&) = default; }; export constexpr std::string token_to_string(const token& t) { @@ -87,7 +96,7 @@ namespace jlx { ';' }}; - static constexpr std::array operators = {{ + static constexpr std::array operators = {{ "=", "+", "-", @@ -99,7 +108,8 @@ namespace jlx { "<=", ">=", ">", - "<" + "<", + "!" }}; void skip_whitespace() { @@ -155,6 +165,7 @@ namespace jlx { return { token_type::String, + "mono_src", buffer.str(), start_line, start_col @@ -189,6 +200,7 @@ namespace jlx { return token { token_type::Number, + "mono_src", buffer.str(), start_line, start_col @@ -227,6 +239,7 @@ namespace jlx { if (std::find(keywords.begin(), keywords.end(), word) != keywords.end()) { return token { token_type::Keyword, + "mono_src", word, start_line, start_col @@ -234,6 +247,7 @@ namespace jlx { } else { return token { token_type::Identifier, + "mono_src", word, start_line, start_col @@ -256,6 +270,7 @@ namespace jlx { source.next(); return token { token_type::Punctuation, + "mono_src", std::string() + val, line, col @@ -293,6 +308,7 @@ namespace jlx { if (std::find(operators.begin(), operators.end(), word) != operators.end()) { return token { token_type::Operator, + "mono_src", word, line, col diff --git a/libjlx/modules/type_checker.cppm b/libjlx/modules/type_checker.cppm new file mode 100644 index 0000000..0a0036b --- /dev/null +++ b/libjlx/modules/type_checker.cppm @@ -0,0 +1,360 @@ +module; + +#include +#include +#include +#include +#include + +export module jlx:type_checker; + +import :tokenizer; +import :ast; + +namespace jlx { + export struct runtime_function { + std::string_view return_type; + std::vector arguments; + }; + + export class type_checker_runtime_context { + public: + virtual std::string get_identifier_type(const std::string_view&) const = 0; + virtual runtime_function get_function(const std::string_view&) const = 0; + }; + + export class type_error : std::runtime_error { + public: + explicit type_error(const token& t, const std::string_view& provided_type) : std::runtime_error(std::format("Type {} is invalid at {}:{}:{}", provided_type, t.source_file, t.line, t.col)) { + + } + }; + + export class type_checker { + static std::size_t size_of(const std::string_view& type) { + if (type == "u8" || type == "i8" || type == "char") { + return 1; + } else if (type == "u16" || type == "i16") { + return 16; + } else if (type == "u32" || type == "i32" || type == "f32") { + return 32; + } else if (type == "u64" || type == "i64" || type == "f64") { + return 64; + } else if (type == "comptime_float" || type == "comptime_int") { + return 0; //This is to make 'comptime_***' types be the lowest priority types for arithmetic operation type resolution. + //All "comptime_***" types will never appear in runtime + } + throw std::runtime_error("Cannot evaluate size of type"); + } + + static bool type_is_integer(const std::string_view& type) { + return type == "u8" || type == "u16" || type == "u32" || type == "u64" || + type == "i8" || type == "i16" || type == "i32" || type == "i64" || + type == "comptime_int"; + } + + static bool type_is_float(const std::string_view& type) { + return type == "f32" || type == "f64" || type == "comptime_float"; + } + + + static bool is_convertible_to(const std::string_view& source_type, const std::string_view& target_type) { + if (target_type == source_type) { + return true; + } + + if (target_type == "i8") { + return + source_type == "comptime_int"; + } + + if (target_type == "i16") { + return + source_type == "i8" || + source_type == "u8" || + source_type == "comptime_int"; + } + + if (target_type == "i32") { + return + source_type == "i8" || + source_type == "u8" || + source_type == "i16" || + source_type == "u16" || + source_type == "comptime_int"; + } + + if (target_type == "i64") { + return + source_type == "i8" || + source_type == "u8" || + source_type == "i16" || + source_type == "u16" || + source_type == "i32" || + source_type == "u32" || + source_type == "comptime_int"; + } + + if (target_type == "u8") { + return + source_type == "comptime_int"; + } + + if (target_type == "u16") { + return + source_type == "u8" || + source_type == "comptime_int"; + } + + if (target_type == "u32") { + return + source_type == "u8" || + source_type == "u16" || + source_type == "comptime_int"; + } + + if (target_type == "u64") { + return + source_type == "u8" || + source_type == "u16" || + source_type == "u32" || + source_type == "comptime_int"; + } + + if (target_type == "char") { + return + source_type == "u8"; + } + + if (target_type == "f32") { + return + source_type == "comptime_float"; + } + + if (target_type == "f64") { + return + source_type == "f32" || + source_type == "comptime_float"; + } + + if (target_type == "comptime_float") { + return + source_type == "comptime_int"; + } + + return false; + } + + static void evaluate_expression_type(std::unique_ptr& expr, const type_checker_runtime_context& ctx) { + const auto& content = expr->t.content; + if (expr->et == EtLiteralValue) { + auto* literal_expression = dynamic_cast(expr.get()); + if (literal_expression == nullptr) { + throw std::runtime_error("Internal type checker/AST error"); + } + + std::string type; + switch(literal_expression->t.type) { + case token_type::Number: + if (std::find(content.begin(), content.end(), '.') != content.end()) { + type = "comptime.float"; + } else { + type = "comptime.int"; + } + break; + case token_type::String: + type = "comptime.string"; + break; + default: + throw std::runtime_error("Unsupported token for literal value expression"); + } + + expr->evaluated_type = type; + return; + } else if (expr->et == EtSingleValueOperation) { + auto* single_op = dynamic_cast(expr.get()); + + if (single_op == nullptr) { + throw std::runtime_error("Internal type checker/AST error"); + } + + if (single_op->operator_token.type != token_type::Operator) { + throw std::runtime_error("Invalid operator token for operation expression"); + } + + auto& opr = single_op->operand; + if (opr->evaluated_type == std::nullopt) { + evaluate_expression_type(opr, ctx); + } + + auto& op = single_op->operator_token.content; + auto base_type = opr->evaluated_type.has_value() ? opr->evaluated_type.value() : "void"; + if (op == "+") { + if (type_is_integer(base_type) || type_is_float(base_type)) { + expr->evaluated_type = std::move(base_type); + return; + } + throw type_error(opr->t, base_type); + } else if(op == "-") { + if (type_is_integer(base_type) || type_is_float(base_type)) { + expr->evaluated_type = std::move(base_type); + //TODO: Maybe run immediately for 'comptime_int' and 'comptime_float'? + return; + } + throw type_error(opr->t, base_type); + } else if (op == "!") { + if (base_type != "boolean") { + throw type_error(opr->t, base_type); + } + + expr->evaluated_type = "boolean"; + return; + } else { + throw std::runtime_error("Cannot type-check expression: invalid operator for single-value operation"); + } + } else if (expr->et == EtDualValueOperation) { + auto* dual_op = dynamic_cast(expr.get()); + + if (dual_op == nullptr) { + throw std::runtime_error("Internal type checker/AST error"); + } + + if (dual_op->operator_token.type != token_type::Operator) { + throw std::runtime_error("Invalid operator token for dual value expression"); + } + + auto& expr0 = dual_op->first_operand; + auto& expr1 = dual_op->second_operand; + if (!expr0->evaluated_type.has_value()) { + evaluate_expression_type(expr0, ctx); + } + + if (!expr1->evaluated_type.has_value()) { + evaluate_expression_type(expr1, ctx); + } + + auto& op = dual_op->operator_token.content; + auto type0 = expr0->evaluated_type.has_value() ? expr0->evaluated_type.value() : "void"; + auto type1 = expr1->evaluated_type.has_value() ? expr1->evaluated_type.value() : "void"; + + if (op == "+") { + if (!type_is_integer(type0) && !type_is_float(type0) && type0 != "string") { + throw type_error(expr0->t, type0); + } + + if (type0 == "string") { + expr->evaluated_type = "string"; + return; + } + + if (type0 == type1) { + expr->evaluated_type = std::move(type0); + return; + } + + if ((type_is_integer(type0) && type_is_integer(type1)) || (type_is_float(type0) && type_is_float(type1))) { + expr->evaluated_type = size_of(type0) > size_of(type1) ? type0 : type1; + return; + } + + if (type_is_float(type0) && type1 == "comptime_int") { + expr->evaluated_type = type0; + return; + } + + if (type_is_float(type1) && type0 == "comptime_int") { + expr->evaluated_type = type1; + return; + } + + throw std::runtime_error("Unsupported type combination for operator '+'"); + } else if (op == "-" || op == "*" || op == "/" || op == "%") { + if (!type_is_integer(type0) || type_is_float(type0)) { + throw type_error(expr0->t, type0); + } + + if ((type_is_integer(type0) && type_is_integer(type1)) || (type_is_float(type0) && type_is_float(type1))) { + expr->evaluated_type = size_of(type0) > size_of(type1) ? type0 : type1; + return; + } + + + if (type_is_float(type0) && type1 == "comptime_int") { + expr->evaluated_type = type0; + return; + } + + if (type_is_float(type1) && type0 == "comptime_int") { + expr->evaluated_type = type1; + return; + } + + throw std::runtime_error("Unsupported type combination for operator '+'"); + } else if (op == "==" || op == "!=") { + if (is_convertible_to(type1, type0) || is_convertible_to(type0, type1)) { + expr->evaluated_type = "boolean"; + return; + } + + throw type_error(expr1->t, type1); + } else if (op == "<=" || op == ">=" || op == ">" || op == "<") { + if (!type_is_integer(type0) && !type_is_float(type0)) { + throw type_error(expr0->t, type0); + } + + if (!type_is_integer(type1) && !type_is_float(type1)) { + throw type_error(expr1->t, type1); + } + + if (!is_convertible_to(type0, type1) && !is_convertible_to(type1, type0)) { + throw type_error(expr1->t, type1); + } + + expr->evaluated_type = "boolean"; + return; + } + } else if (expr->et == EtFunctionCall) { + auto* call = dynamic_cast(expr.get()); + + if (call == nullptr) { + throw std::runtime_error("Internal type checker/AST error"); + } + + auto fn = ctx.get_function(call->function_name); + + if (fn.arguments.size() != call->arguments.size()) { + throw std::runtime_error("Invalid number of arguments for function call"); + } + + for(auto i = 0ULL; i < fn.arguments.size(); i++) { + auto& target = fn.arguments[i]; + + auto& arg = call->arguments[i]; + + if (!arg->evaluated_type.has_value()) { + evaluate_expression_type(arg, ctx); + } + + auto src = arg->evaluated_type.has_value() ? arg->evaluated_type.value() : "void"; + if (!is_convertible_to(src, target)) { + throw type_error(arg->t, src); + } + } + + expr->evaluated_type = fn.return_type; + return; + } else if (expr->et == EtIdentifier) { + auto* ident_expr = dynamic_cast(expr.get()); + + if (ident_expr == nullptr) { + throw std::runtime_error("Internal type checker/AST error"); + } + + expr->evaluated_type = ctx.get_identifier_type(ident_expr->name); + return; + } + + throw std::runtime_error("Cannot type-check expression: Unknown expression"); + } + }; +} \ No newline at end of file