[libjlx] Implement basic type_checking

This commit is contained in:
John Stefanelli 2025-05-30 10:02:23 +02:00
parent 07d2a25712
commit e3d912c6a5
Signed by: jstefanelli
GPG key ID: 60EDE2437640D2AA
4 changed files with 304 additions and 45 deletions

View file

@ -63,5 +63,15 @@ int main(int argc, char** argv) {
std::cout << "Parsed " << rt->statements.size() << " statements" << std::endl; std::cout << "Parsed " << rt->statements.size() << " statements" << std::endl;
auto type_checker = jlx::type_checker(rt->statements.begin(), rt->statements.end());
try {
type_checker.check();
} catch (jlx::type_error& err) {
std::cerr << "Type error: " << err.what() << std::endl;
return 1;
}
std::cout << "Type checking OK." << std::endl;
return 0; return 0;
} }

View file

@ -17,8 +17,8 @@ namespace jlx {
Block, Block,
FunctionDeclaration, FunctionDeclaration,
VariableDeclaration, VariableDeclaration,
LiteralValue,
IfStatement, IfStatement,
ReturnStatement,
}; };
export enum expression_type { export enum expression_type {
@ -52,7 +52,7 @@ namespace jlx {
} }
explicit root_statement(std::vector<std::unique_ptr<statement>> statements, const token& t) : statement(Root, t), statements(std::move(statements)) { root_statement(std::vector<std::unique_ptr<statement>> statements, const token& t) : statement(Root, t), statements(std::move(statements)) {
} }
@ -77,11 +77,25 @@ namespace jlx {
expression_type et = EtInvalid; expression_type et = EtInvalid;
std::optional<std::string> evaluated_type; std::optional<std::string> evaluated_type;
virtual std::unique_ptr<expression> clone() const = 0; [[nodiscard]] virtual std::unique_ptr<expression> clone() const = 0;
~expression() override = default; ~expression() override = default;
}; };
export struct return_statement : public statement {
explicit return_statement(const token& t) : statement(ReturnStatement, t) {
}
return_statement(std::unique_ptr<expression> expression, const token& t) : statement(ReturnStatement, t), expression(std::move(expression)) {
}
std::unique_ptr<expression> expression;
~return_statement() override = default;
};
export struct block : public statement { export struct block : public statement {
explicit block(const token& t) : statement(Block, t) { explicit block(const token& t) : statement(Block, t) {
@ -455,6 +469,13 @@ namespace jlx {
} }
} }
fail_invalid_token(*current); fail_invalid_token(*current);
break;
case Keyword:
if (previous != nullptr) {
return previous;
}
fail_invalid_token(*current);
break;
default: default:
fail_invalid_token(*current); fail_invalid_token(*current);
} }
@ -531,6 +552,23 @@ namespace jlx {
return std::make_unique<function_declaration>(start, std::move(function_name), std::move(params), std::move(return_type), std::move(block)); return std::make_unique<function_declaration>(start, std::move(function_name), std::move(params), std::move(return_type), std::move(block));
} }
/// Parses a `return` statement starting at the current token.
///
/// The current token must be the keyword `return`; otherwise it is reported
/// through fail_invalid_token. After advancing past the keyword, an
/// expression is parsed unless the cursor already equals `last`, in which
/// case the resulting node carries a null expression.
std::unique_ptr<return_statement> parse_return_statement() {
    // Copy the token now: the cursor moves on below, and the node is built
    // from the token the statement started at.
    const auto return_token = *current;

    const bool at_return_keyword = current->type == Keyword && current->content == "return";
    if (!at_return_keyword) {
        fail_invalid_token(*current);
    }
    next(false);

    std::unique_ptr<expression> value;
    if (current != last) {
        value = parse_expression();
    }

    return std::make_unique<return_statement>(std::move(value), return_token);
}
std::string parse_type(){ std::string parse_type(){
if (current->type != Identifier) { if (current->type != Identifier) {
fail_invalid_token(*current); fail_invalid_token(*current);
@ -551,6 +589,8 @@ namespace jlx {
return parse_if_statement(); return parse_if_statement();
} else if (current->content == "fun" && top_level) { } else if (current->content == "fun" && top_level) {
return parse_function(); return parse_function();
} else if (current->content == "return") {
return parse_return_statement();
} }
} }

View file

@ -80,13 +80,14 @@ namespace jlx {
export class tokenizer { export class tokenizer {
source_stream<char> source; source_stream<char> source;
static constexpr std::array<std::string, 5> keywords = {{ static constexpr std::array<std::string, 6> keywords = {{
"if", "if",
"else", "else",
"fun", "fun",
//"struct", //"struct",
"let", "let",
"var" "var",
"return"
}}; }};
static constexpr std::array<char, 7> punctuations = {{ static constexpr std::array<char, 7> punctuations = {{

View file

@ -1,5 +1,6 @@
module; module;
#include <ranges>
#include <stdexcept> #include <stdexcept>
#include <format> #include <format>
#include <string> #include <string>
@ -19,18 +20,20 @@ namespace jlx {
export class type_checker_runtime_context { export class type_checker_runtime_context {
public: public:
virtual std::string get_identifier_type(const std::string_view&) const = 0; virtual ~type_checker_runtime_context() = default;
virtual runtime_function get_function(const std::string_view&) const = 0;
virtual std::string_view get_identifier_type(const std::string_view&) const = 0;
virtual const runtime_function& get_function(const std::string_view&) const = 0;
}; };
export class type_error : std::runtime_error { export class type_error : public std::runtime_error {
public: public:
explicit type_error(const token& t, const std::string_view& provided_type) : std::runtime_error(std::format("Type {} is invalid at {}:{}:{}", provided_type, t.source_file, t.line, t.col)) { explicit type_error(const token& t, const std::string_view& provided_type) : std::runtime_error(std::format("Type {} is invalid at {}:{}:{}", provided_type, t.source_file, t.line, t.col)) {
} }
}; };
export class type_checker { class type_checker_utils {
static std::size_t size_of(const std::string_view& type) { static std::size_t size_of(const std::string_view& type) {
if (type == "u8" || type == "i8" || type == "char") { if (type == "u8" || type == "i8" || type == "char") {
return 1; return 1;
@ -57,7 +60,7 @@ namespace jlx {
return type == "f32" || type == "f64" || type == "comptime_float"; return type == "f32" || type == "f64" || type == "comptime_float";
} }
public:
static bool is_convertible_to(const std::string_view& source_type, const std::string_view& target_type) { static bool is_convertible_to(const std::string_view& source_type, const std::string_view& target_type) {
if (target_type == source_type) { if (target_type == source_type) {
return true; return true;
@ -145,10 +148,10 @@ namespace jlx {
return false; return false;
} }
static void evaluate_expression_type(std::unique_ptr<expression>& expr, const type_checker_runtime_context& ctx) { static void evaluate_expression_type(expression& expr, const type_checker_runtime_context& ctx) {
const auto& content = expr->t.content; const auto& content = expr.t.content;
if (expr->et == EtLiteralValue) { if (expr.et == EtLiteralValue) {
auto* literal_expression = dynamic_cast<literal_value*>(expr.get()); auto* literal_expression = dynamic_cast<literal_value*>(&expr);
if (literal_expression == nullptr) { if (literal_expression == nullptr) {
throw std::runtime_error("Internal type checker/AST error"); throw std::runtime_error("Internal type checker/AST error");
} }
@ -157,22 +160,25 @@ namespace jlx {
switch(literal_expression->t.type) { switch(literal_expression->t.type) {
case token_type::Number: case token_type::Number:
if (std::find(content.begin(), content.end(), '.') != content.end()) { if (std::find(content.begin(), content.end(), '.') != content.end()) {
type = "comptime.float"; type = "comptime_float";
} else { } else {
type = "comptime.int"; type = "comptime_int";
} }
break; break;
case token_type::String: case token_type::String:
type = "comptime.string"; type = "string";
break;
case token_type::Boolean:
type = "boolean";
break; break;
default: default:
throw std::runtime_error("Unsupported token for literal value expression"); throw std::runtime_error("Unsupported token for literal value expression");
} }
expr->evaluated_type = type; expr.evaluated_type = type;
return; return;
} else if (expr->et == EtSingleValueOperation) { } else if (expr.et == EtSingleValueOperation) {
auto* single_op = dynamic_cast<single_operation*>(expr.get()); auto* single_op = dynamic_cast<single_operation*>(&expr);
if (single_op == nullptr) { if (single_op == nullptr) {
throw std::runtime_error("Internal type checker/AST error"); throw std::runtime_error("Internal type checker/AST error");
@ -184,20 +190,20 @@ namespace jlx {
auto& opr = single_op->operand; auto& opr = single_op->operand;
if (opr->evaluated_type == std::nullopt) { if (opr->evaluated_type == std::nullopt) {
evaluate_expression_type(opr, ctx); evaluate_expression_type(*opr, ctx);
} }
auto& op = single_op->operator_token.content; auto& op = single_op->operator_token.content;
auto base_type = opr->evaluated_type.has_value() ? opr->evaluated_type.value() : "void"; auto base_type = opr->evaluated_type.has_value() ? opr->evaluated_type.value() : "void";
if (op == "+") { if (op == "+") {
if (type_is_integer(base_type) || type_is_float(base_type)) { if (type_is_integer(base_type) || type_is_float(base_type)) {
expr->evaluated_type = std::move(base_type); expr.evaluated_type = std::move(base_type);
return; return;
} }
throw type_error(opr->t, base_type); throw type_error(opr->t, base_type);
} else if(op == "-") { } else if(op == "-") {
if (type_is_integer(base_type) || type_is_float(base_type)) { if (type_is_integer(base_type) || type_is_float(base_type)) {
expr->evaluated_type = std::move(base_type); expr.evaluated_type = std::move(base_type);
//TODO: Maybe run immediately for 'comptime_int' and 'comptime_float'? //TODO: Maybe run immediately for 'comptime_int' and 'comptime_float'?
return; return;
} }
@ -207,13 +213,13 @@ namespace jlx {
throw type_error(opr->t, base_type); throw type_error(opr->t, base_type);
} }
expr->evaluated_type = "boolean"; expr.evaluated_type = "boolean";
return; return;
} else { } else {
throw std::runtime_error("Cannot type-check expression: invalid operator for single-value operation"); throw std::runtime_error("Cannot type-check expression: invalid operator for single-value operation");
} }
} else if (expr->et == EtDualValueOperation) { } else if (expr.et == EtDualValueOperation) {
auto* dual_op = dynamic_cast<dual_operation*>(expr.get()); auto* dual_op = dynamic_cast<dual_operation*>(&expr);
if (dual_op == nullptr) { if (dual_op == nullptr) {
throw std::runtime_error("Internal type checker/AST error"); throw std::runtime_error("Internal type checker/AST error");
@ -226,11 +232,11 @@ namespace jlx {
auto& expr0 = dual_op->first_operand; auto& expr0 = dual_op->first_operand;
auto& expr1 = dual_op->second_operand; auto& expr1 = dual_op->second_operand;
if (!expr0->evaluated_type.has_value()) { if (!expr0->evaluated_type.has_value()) {
evaluate_expression_type(expr0, ctx); evaluate_expression_type(*expr0, ctx);
} }
if (!expr1->evaluated_type.has_value()) { if (!expr1->evaluated_type.has_value()) {
evaluate_expression_type(expr1, ctx); evaluate_expression_type(*expr1, ctx);
} }
auto& op = dual_op->operator_token.content; auto& op = dual_op->operator_token.content;
@ -243,27 +249,27 @@ namespace jlx {
} }
if (type0 == "string") { if (type0 == "string") {
expr->evaluated_type = "string"; expr.evaluated_type = "string";
return; return;
} }
if (type0 == type1) { if (type0 == type1) {
expr->evaluated_type = std::move(type0); expr.evaluated_type = std::move(type0);
return; return;
} }
if ((type_is_integer(type0) && type_is_integer(type1)) || (type_is_float(type0) && type_is_float(type1))) { if ((type_is_integer(type0) && type_is_integer(type1)) || (type_is_float(type0) && type_is_float(type1))) {
expr->evaluated_type = size_of(type0) > size_of(type1) ? type0 : type1; expr.evaluated_type = size_of(type0) > size_of(type1) ? type0 : type1;
return; return;
} }
if (type_is_float(type0) && type1 == "comptime_int") { if (type_is_float(type0) && type1 == "comptime_int") {
expr->evaluated_type = type0; expr.evaluated_type = type0;
return; return;
} }
if (type_is_float(type1) && type0 == "comptime_int") { if (type_is_float(type1) && type0 == "comptime_int") {
expr->evaluated_type = type1; expr.evaluated_type = type1;
return; return;
} }
@ -274,25 +280,25 @@ namespace jlx {
} }
if ((type_is_integer(type0) && type_is_integer(type1)) || (type_is_float(type0) && type_is_float(type1))) { if ((type_is_integer(type0) && type_is_integer(type1)) || (type_is_float(type0) && type_is_float(type1))) {
expr->evaluated_type = size_of(type0) > size_of(type1) ? type0 : type1; expr.evaluated_type = size_of(type0) > size_of(type1) ? type0 : type1;
return; return;
} }
if (type_is_float(type0) && type1 == "comptime_int") { if (type_is_float(type0) && type1 == "comptime_int") {
expr->evaluated_type = type0; expr.evaluated_type = type0;
return; return;
} }
if (type_is_float(type1) && type0 == "comptime_int") { if (type_is_float(type1) && type0 == "comptime_int") {
expr->evaluated_type = type1; expr.evaluated_type = type1;
return; return;
} }
throw std::runtime_error("Unsupported type combination for operator '+'"); throw std::runtime_error("Unsupported type combination for operator '+'");
} else if (op == "==" || op == "!=") { } else if (op == "==" || op == "!=") {
if (is_convertible_to(type1, type0) || is_convertible_to(type0, type1)) { if (is_convertible_to(type1, type0) || is_convertible_to(type0, type1)) {
expr->evaluated_type = "boolean"; expr.evaluated_type = "boolean";
return; return;
} }
@ -310,11 +316,11 @@ namespace jlx {
throw type_error(expr1->t, type1); throw type_error(expr1->t, type1);
} }
expr->evaluated_type = "boolean"; expr.evaluated_type = "boolean";
return; return;
} }
} else if (expr->et == EtFunctionCall) { } else if (expr.et == EtFunctionCall) {
auto* call = dynamic_cast<function_call*>(expr.get()); auto* call = dynamic_cast<function_call*>(&expr);
if (call == nullptr) { if (call == nullptr) {
throw std::runtime_error("Internal type checker/AST error"); throw std::runtime_error("Internal type checker/AST error");
@ -332,7 +338,7 @@ namespace jlx {
auto& arg = call->arguments[i]; auto& arg = call->arguments[i];
if (!arg->evaluated_type.has_value()) { if (!arg->evaluated_type.has_value()) {
evaluate_expression_type(arg, ctx); evaluate_expression_type(*arg, ctx);
} }
auto src = arg->evaluated_type.has_value() ? arg->evaluated_type.value() : "void"; auto src = arg->evaluated_type.has_value() ? arg->evaluated_type.value() : "void";
@ -341,20 +347,222 @@ namespace jlx {
} }
} }
expr->evaluated_type = fn.return_type; expr.evaluated_type = fn.return_type;
return; return;
} else if (expr->et == EtIdentifier) { } else if (expr.et == EtIdentifier) {
auto* ident_expr = dynamic_cast<identifier_expression*>(expr.get()); auto* ident_expr = dynamic_cast<identifier_expression*>(&expr);
if (ident_expr == nullptr) { if (ident_expr == nullptr) {
throw std::runtime_error("Internal type checker/AST error"); throw std::runtime_error("Internal type checker/AST error");
} }
expr->evaluated_type = ctx.get_identifier_type(ident_expr->name); expr.evaluated_type = ctx.get_identifier_type(ident_expr->name);
return; return;
} }
throw std::runtime_error("Cannot type-check expression: Unknown expression"); throw std::runtime_error("Cannot type-check expression: Unknown expression");
} }
}; };
/// Iterator types accepted by type_checker: bidirectional iterators whose
/// value type is `std::unique_ptr<statement>`.
///
/// Both constraints are written as a conjunction of concepts so that they are
/// actually evaluated. Inside a requires-expression, a line such as
/// `std::bidirectional_iterator<T>;` is only a simple requirement (it checks
/// that the expression is well-formed, not that it is true).
/// `std::iter_value_t` is used instead of `T::value_type` so iterators without
/// that member typedef (e.g. raw pointers) are accepted as well.
export template<class T>
concept statement_iterator =
    std::bidirectional_iterator<T> &&
    std::same_as<std::iter_value_t<T>, std::unique_ptr<statement>>;
/// One lexical scope tracked by the type checker.
struct type_checker_scope {
    /// Variable name -> type name for variables declared in this scope.
    std::unordered_map<std::string, std::string> variables{};

    /// Function name -> signature for functions declared in this scope.
    std::unordered_map<std::string, runtime_function> functions{};

    /// Declared return type when this scope is a function body; empty otherwise.
    std::optional<std::string> intended_return_type{};
};
export template<statement_iterator T, std::sentinel_for<T> E>
class type_checker : public type_checker_runtime_context {
private:
T current;
E end;
std::vector<type_checker_scope> scopes;
void check_statement(statement& s, bool skip_block_scope = false) {
switch (s.type) {
case Root: {
auto* root = dynamic_cast<root_statement*>(&s);
for (auto& st : root->statements) {
check_statement(*st);
}
}
break;
case Block: {
auto* blk = dynamic_cast<block*>(&s);
if (!skip_block_scope) {
scopes.emplace_back(type_checker_scope{});
}
for (auto& st : blk->statements) {
check_statement(*st);
}
if (!skip_block_scope) {
scopes.pop_back();
}
}
break;
case IfStatement: {
auto* if_stmt = dynamic_cast<if_statement*>(&s);
auto& cond = if_stmt->condition;
if (!cond->evaluated_type.has_value()) {
type_checker_utils::evaluate_expression_type(*if_stmt->condition, *this);
}
auto tp = cond->evaluated_type.has_value() ? cond->evaluated_type.value() : "void";
if (tp != "boolean") {
throw type_error(cond->t, tp);
}
check_statement(*if_stmt->block);
}
break;
case VariableDeclaration: {
auto* decl = dynamic_cast<variable_declaration*>(&s);
if (has_variable(decl->name)) {
throw std::runtime_error("Duplicate variable name detected");
}
std::optional<std::string> init_type = std::nullopt;
if (decl->initial_expression != nullptr) {
if (!decl->initial_expression->evaluated_type.has_value()) {
type_checker_utils::evaluate_expression_type(*decl->initial_expression, *this);
}
init_type = decl->initial_expression->evaluated_type;
}
auto type = decl->type;
if (init_type.has_value()) {
if (!type.has_value()) {
type = init_type;
} else if (!type_checker_utils::is_convertible_to(init_type.value(), type.value())) {
throw type_error(decl->initial_expression->t, init_type.value());
}
}
if (!type.has_value()) {
throw std::runtime_error("Cannot infer variable declaration type");
}
scopes[scopes.size() - 1].variables.insert_or_assign(decl->name, type.value());
}
break;
case FunctionDeclaration: {
auto* decl = dynamic_cast<function_declaration*>(&s);
if (has_function(decl->name)) {
throw std::runtime_error("Duplicate function name detected");
}
if (!decl->return_type.has_value()) {
decl->return_type = "void";
}
runtime_function f;
f.return_type = decl->return_type.value();
for (auto& a : decl->parameters) {
f.arguments.emplace_back(a.type);
}
scopes[scopes.size() - 1].functions.insert_or_assign(decl->name, f);
auto& function_scope = scopes.emplace_back(type_checker_scope{});
function_scope.intended_return_type = decl->return_type;
check_statement(*decl->body, true);
scopes.pop_back();
}
break;
case Expression: {
auto* expr = dynamic_cast<expression*>(&s);
type_checker_utils::evaluate_expression_type(*expr, *this);
}
break;
case ReturnStatement: {
auto* rtr = dynamic_cast<return_statement*>(&s);
auto& scope = scopes.back();
if (rtr->expression != nullptr && !rtr->expression->evaluated_type.has_value()) {
type_checker_utils::evaluate_expression_type(*rtr->expression, *this);
}
auto rvl = scope.intended_return_type.has_value() ? scope.intended_return_type.value() : "void";
auto exvl = rtr->expression != nullptr && rtr->expression->evaluated_type.has_value() ? rtr->expression->evaluated_type.value() : "void";
if (!type_checker_utils::is_convertible_to(exvl, rvl)) {
throw type_error(rtr->expression->t, exvl);
}
}
break;
}
}
[[nodiscard]] bool has_variable(const std::string_view& name) const {
auto& scope = scopes[scopes.size() - 1];
auto itx = scope.variables.find(std::string(name));
if (itx != scope.variables.end()) {
return true;
}
return false;
}
[[nodiscard]] bool has_function(const std::string_view& name) const {
for (const auto& scope : std::ranges::reverse_view(scopes)) {
auto itx = scope.functions.find(std::string(name));
if (itx != scope.functions.end()) {
return true;
}
}
return false;
}
public:
type_checker(T start, E end) : current(start), end(end) {
scopes.emplace_back();
}
void check() {
while (current != end) {
check_statement(**current);
++current;
}
}
[[nodiscard]] std::string_view get_identifier_type(const std::string_view &name) const override {
for (const auto & scope : std::ranges::reverse_view(scopes)) {
auto itx = scope.variables.find(std::string(name));
if (itx != scope.variables.end()) {
return itx->second;
}
}
throw std::runtime_error("Missing identifier");
}
[[nodiscard]] const runtime_function& get_function(const std::string_view &name) const override {
for (const auto& scope : std::ranges::reverse_view(scopes)) {
if (auto itx = scope.functions.find(std::string(name)); itx != scope.functions.end()) {
return itx->second;
}
}
throw std::runtime_error("Missing function");
}
};
} }