[libjlx] Implement basic type_checking

This commit is contained in:
John Stefanelli 2025-05-30 10:02:23 +02:00
parent 07d2a25712
commit e3d912c6a5
Signed by: jstefanelli
GPG key ID: 60EDE2437640D2AA
4 changed files with 304 additions and 45 deletions

View file

@ -63,5 +63,15 @@ int main(int argc, char** argv) {
std::cout << "Parsed " << rt->statements.size() << " statements" << std::endl;
auto type_checker = jlx::type_checker(rt->statements.begin(), rt->statements.end());
try {
type_checker.check();
} catch (jlx::type_error& err) {
std::cerr << "Type error: " << err.what() << std::endl;
return 1;
}
std::cout << "Type checking OK." << std::endl;
return 0;
}

View file

@ -17,8 +17,8 @@ namespace jlx {
Block,
FunctionDeclaration,
VariableDeclaration,
LiteralValue,
IfStatement,
ReturnStatement,
};
export enum expression_type {
@ -52,7 +52,7 @@ namespace jlx {
}
explicit root_statement(std::vector<std::unique_ptr<statement>> statements, const token& t) : statement(Root, t), statements(std::move(statements)) {
root_statement(std::vector<std::unique_ptr<statement>> statements, const token& t) : statement(Root, t), statements(std::move(statements)) {
}
@ -77,11 +77,25 @@ namespace jlx {
expression_type et = EtInvalid;
std::optional<std::string> evaluated_type;
virtual std::unique_ptr<expression> clone() const = 0;
[[nodiscard]] virtual std::unique_ptr<expression> clone() const = 0;
~expression() override = default;
};
export struct return_statement : public statement {
explicit return_statement(const token& t) : statement(ReturnStatement, t) {
}
return_statement(std::unique_ptr<expression> expression, const token& t) : statement(ReturnStatement, t), expression(std::move(expression)) {
}
std::unique_ptr<expression> expression;
~return_statement() override = default;
};
export struct block : public statement {
explicit block(const token& t) : statement(Block, t) {
@ -455,6 +469,13 @@ namespace jlx {
}
}
fail_invalid_token(*current);
break;
case Keyword:
if (previous != nullptr) {
return previous;
}
fail_invalid_token(*current);
break;
default:
fail_invalid_token(*current);
}
@ -531,6 +552,23 @@ namespace jlx {
return std::make_unique<function_declaration>(start, std::move(function_name), std::move(params), std::move(return_type), std::move(block));
}
/// Parses a `return` statement starting at the current token.
///
/// The current token must be the keyword `return`; otherwise
/// `fail_invalid_token` is invoked on it. After stepping past the keyword, an
/// expression is parsed as the returned value unless the token stream is
/// exhausted (`current == last`), in which case the statement carries no value.
///
/// @return The new statement, tagged with a copy of the `return` token.
std::unique_ptr<return_statement> parse_return_statement() {
    // Copy the token up front: `current` moves on before the node is built.
    auto return_token = *current;

    const bool at_return_keyword = current->type == Keyword && current->content == "return";
    if (!at_return_keyword) {
        fail_invalid_token(*current);
    }

    // NOTE(review): the meaning of the `false` argument is defined by next(),
    // which is outside this view.
    next(false);

    std::unique_ptr<expression> value = nullptr;
    if (current != last) {
        value = parse_expression();
    }

    return std::make_unique<return_statement>(std::move(value), return_token);
}
std::string parse_type(){
if (current->type != Identifier) {
fail_invalid_token(*current);
@ -551,6 +589,8 @@ namespace jlx {
return parse_if_statement();
} else if (current->content == "fun" && top_level) {
return parse_function();
} else if (current->content == "return") {
return parse_return_statement();
}
}

View file

@ -80,13 +80,14 @@ namespace jlx {
export class tokenizer {
source_stream<char> source;
static constexpr std::array<std::string, 5> keywords = {{
static constexpr std::array<std::string, 6> keywords = {{
"if",
"else",
"fun",
//"struct",
"let",
"var"
"var",
"return"
}};
static constexpr std::array<char, 7> punctuations = {{

View file

@ -1,5 +1,6 @@
module;
#include <ranges>
#include <stdexcept>
#include <format>
#include <string>
@ -19,18 +20,20 @@ namespace jlx {
export class type_checker_runtime_context {
public:
virtual std::string get_identifier_type(const std::string_view&) const = 0;
virtual runtime_function get_function(const std::string_view&) const = 0;
virtual ~type_checker_runtime_context() = default;
virtual std::string_view get_identifier_type(const std::string_view&) const = 0;
virtual const runtime_function& get_function(const std::string_view&) const = 0;
};
export class type_error : std::runtime_error {
export class type_error : public std::runtime_error {
public:
explicit type_error(const token& t, const std::string_view& provided_type) : std::runtime_error(std::format("Type {} is invalid at {}:{}:{}", provided_type, t.source_file, t.line, t.col)) {
}
};
export class type_checker {
class type_checker_utils {
static std::size_t size_of(const std::string_view& type) {
if (type == "u8" || type == "i8" || type == "char") {
return 1;
@ -57,7 +60,7 @@ namespace jlx {
return type == "f32" || type == "f64" || type == "comptime_float";
}
public:
static bool is_convertible_to(const std::string_view& source_type, const std::string_view& target_type) {
if (target_type == source_type) {
return true;
@ -145,10 +148,10 @@ namespace jlx {
return false;
}
static void evaluate_expression_type(std::unique_ptr<expression>& expr, const type_checker_runtime_context& ctx) {
const auto& content = expr->t.content;
if (expr->et == EtLiteralValue) {
auto* literal_expression = dynamic_cast<literal_value*>(expr.get());
static void evaluate_expression_type(expression& expr, const type_checker_runtime_context& ctx) {
const auto& content = expr.t.content;
if (expr.et == EtLiteralValue) {
auto* literal_expression = dynamic_cast<literal_value*>(&expr);
if (literal_expression == nullptr) {
throw std::runtime_error("Internal type checker/AST error");
}
@ -157,22 +160,25 @@ namespace jlx {
switch(literal_expression->t.type) {
case token_type::Number:
if (std::find(content.begin(), content.end(), '.') != content.end()) {
type = "comptime.float";
type = "comptime_float";
} else {
type = "comptime.int";
type = "comptime_int";
}
break;
case token_type::String:
type = "comptime.string";
type = "string";
break;
case token_type::Boolean:
type = "boolean";
break;
default:
throw std::runtime_error("Unsupported token for literal value expression");
}
expr->evaluated_type = type;
expr.evaluated_type = type;
return;
} else if (expr->et == EtSingleValueOperation) {
auto* single_op = dynamic_cast<single_operation*>(expr.get());
} else if (expr.et == EtSingleValueOperation) {
auto* single_op = dynamic_cast<single_operation*>(&expr);
if (single_op == nullptr) {
throw std::runtime_error("Internal type checker/AST error");
@ -184,20 +190,20 @@ namespace jlx {
auto& opr = single_op->operand;
if (opr->evaluated_type == std::nullopt) {
evaluate_expression_type(opr, ctx);
evaluate_expression_type(*opr, ctx);
}
auto& op = single_op->operator_token.content;
auto base_type = opr->evaluated_type.has_value() ? opr->evaluated_type.value() : "void";
if (op == "+") {
if (type_is_integer(base_type) || type_is_float(base_type)) {
expr->evaluated_type = std::move(base_type);
expr.evaluated_type = std::move(base_type);
return;
}
throw type_error(opr->t, base_type);
} else if(op == "-") {
if (type_is_integer(base_type) || type_is_float(base_type)) {
expr->evaluated_type = std::move(base_type);
expr.evaluated_type = std::move(base_type);
//TODO: Maybe run immediately for 'comptime_int' and 'comptime_float'?
return;
}
@ -207,13 +213,13 @@ namespace jlx {
throw type_error(opr->t, base_type);
}
expr->evaluated_type = "boolean";
expr.evaluated_type = "boolean";
return;
} else {
throw std::runtime_error("Cannot type-check expression: invalid operator for single-value operation");
}
} else if (expr->et == EtDualValueOperation) {
auto* dual_op = dynamic_cast<dual_operation*>(expr.get());
} else if (expr.et == EtDualValueOperation) {
auto* dual_op = dynamic_cast<dual_operation*>(&expr);
if (dual_op == nullptr) {
throw std::runtime_error("Internal type checker/AST error");
@ -226,11 +232,11 @@ namespace jlx {
auto& expr0 = dual_op->first_operand;
auto& expr1 = dual_op->second_operand;
if (!expr0->evaluated_type.has_value()) {
evaluate_expression_type(expr0, ctx);
evaluate_expression_type(*expr0, ctx);
}
if (!expr1->evaluated_type.has_value()) {
evaluate_expression_type(expr1, ctx);
evaluate_expression_type(*expr1, ctx);
}
auto& op = dual_op->operator_token.content;
@ -243,27 +249,27 @@ namespace jlx {
}
if (type0 == "string") {
expr->evaluated_type = "string";
expr.evaluated_type = "string";
return;
}
if (type0 == type1) {
expr->evaluated_type = std::move(type0);
expr.evaluated_type = std::move(type0);
return;
}
if ((type_is_integer(type0) && type_is_integer(type1)) || (type_is_float(type0) && type_is_float(type1))) {
expr->evaluated_type = size_of(type0) > size_of(type1) ? type0 : type1;
expr.evaluated_type = size_of(type0) > size_of(type1) ? type0 : type1;
return;
}
if (type_is_float(type0) && type1 == "comptime_int") {
expr->evaluated_type = type0;
expr.evaluated_type = type0;
return;
}
if (type_is_float(type1) && type0 == "comptime_int") {
expr->evaluated_type = type1;
expr.evaluated_type = type1;
return;
}
@ -274,25 +280,25 @@ namespace jlx {
}
if ((type_is_integer(type0) && type_is_integer(type1)) || (type_is_float(type0) && type_is_float(type1))) {
expr->evaluated_type = size_of(type0) > size_of(type1) ? type0 : type1;
expr.evaluated_type = size_of(type0) > size_of(type1) ? type0 : type1;
return;
}
if (type_is_float(type0) && type1 == "comptime_int") {
expr->evaluated_type = type0;
expr.evaluated_type = type0;
return;
}
if (type_is_float(type1) && type0 == "comptime_int") {
expr->evaluated_type = type1;
expr.evaluated_type = type1;
return;
}
throw std::runtime_error("Unsupported type combination for operator '+'");
} else if (op == "==" || op == "!=") {
if (is_convertible_to(type1, type0) || is_convertible_to(type0, type1)) {
expr->evaluated_type = "boolean";
expr.evaluated_type = "boolean";
return;
}
@ -310,11 +316,11 @@ namespace jlx {
throw type_error(expr1->t, type1);
}
expr->evaluated_type = "boolean";
expr.evaluated_type = "boolean";
return;
}
} else if (expr->et == EtFunctionCall) {
auto* call = dynamic_cast<function_call*>(expr.get());
} else if (expr.et == EtFunctionCall) {
auto* call = dynamic_cast<function_call*>(&expr);
if (call == nullptr) {
throw std::runtime_error("Internal type checker/AST error");
@ -332,7 +338,7 @@ namespace jlx {
auto& arg = call->arguments[i];
if (!arg->evaluated_type.has_value()) {
evaluate_expression_type(arg, ctx);
evaluate_expression_type(*arg, ctx);
}
auto src = arg->evaluated_type.has_value() ? arg->evaluated_type.value() : "void";
@ -341,20 +347,222 @@ namespace jlx {
}
}
expr->evaluated_type = fn.return_type;
expr.evaluated_type = fn.return_type;
return;
} else if (expr->et == EtIdentifier) {
auto* ident_expr = dynamic_cast<identifier_expression*>(expr.get());
} else if (expr.et == EtIdentifier) {
auto* ident_expr = dynamic_cast<identifier_expression*>(&expr);
if (ident_expr == nullptr) {
throw std::runtime_error("Internal type checker/AST error");
}
expr->evaluated_type = ctx.get_identifier_type(ident_expr->name);
expr.evaluated_type = ctx.get_identifier_type(ident_expr->name);
return;
}
throw std::runtime_error("Cannot type-check expression: Unknown expression");
}
};
/// Satisfied by bidirectional iterators whose value type is
/// `std::unique_ptr<statement>`.
///
/// Both conditions are real constraints. Writing
/// `std::bidirectional_iterator<T>;` as a statement inside a requires-expression
/// would only check that the expression is well-formed, never that it is true.
/// `std::iter_value_t` is used instead of `T::value_type` so that raw pointers
/// into a statement array also qualify.
export template<class T>
concept statement_iterator =
    std::bidirectional_iterator<T> &&
    std::same_as<std::iter_value_t<T>, std::unique_ptr<statement>>;
/// One lexical scope tracked by the type checker.
struct type_checker_scope {
    /// Variable name -> type name, for variables declared in this scope.
    std::unordered_map<std::string, std::string> variables{};
    /// Function name -> signature, for functions declared in this scope.
    std::unordered_map<std::string, runtime_function> functions{};
    /// Declared return type when this scope is a function body; empty otherwise.
    std::optional<std::string> intended_return_type{};
};
/// Type-checks a range of statements.
///
/// The checker keeps a stack of scopes (`scopes`); the constructor pushes the
/// outermost one. Blocks and function bodies push and pop their own scope.
/// The class also implements `type_checker_runtime_context`, so expression
/// evaluation can resolve identifiers and functions against the scope stack.
///
/// Errors are reported by throwing: `type_error` for type mismatches and
/// `std::runtime_error` for structural problems (duplicate names, missing
/// identifiers, types that cannot be inferred, malformed AST nodes).
///
/// @tparam T Iterator over `std::unique_ptr<statement>`.
/// @tparam E Sentinel type marking the end of the range.
export template<statement_iterator T, std::sentinel_for<T> E>
class type_checker : public type_checker_runtime_context {
private:
    T current;
    E end;
    std::vector<type_checker_scope> scopes;

    /// Downcasts `s` to the concrete node type announced by its tag.
    /// @throws std::runtime_error if the tag and the dynamic type disagree
    ///         (same message the expression evaluator uses for this case).
    template<class Node>
    static Node& as_node(statement& s) {
        auto* node = dynamic_cast<Node*>(&s);
        if (node == nullptr) {
            throw std::runtime_error("Internal type checker/AST error");
        }
        return *node;
    }

    /// Return type of the innermost enclosing function, or "void" when no
    /// scope on the stack declares one. Nested block scopes carry no return
    /// type of their own, so the stack is searched from the inside out.
    [[nodiscard]] std::string enclosing_return_type() const {
        for (const auto& scope : std::ranges::reverse_view(scopes)) {
            if (scope.intended_return_type.has_value()) {
                return scope.intended_return_type.value();
            }
        }
        return "void";
    }

    /// Checks a single statement, recursing into nested statements.
    ///
    /// @param s                Statement to check.
    /// @param skip_block_scope When `s` is a block, reuse the current scope
    ///                         instead of pushing a new one (used for function
    ///                         bodies, whose scope the declaration already pushed).
    /// @throws type_error         on a type mismatch.
    /// @throws std::runtime_error on structural errors.
    void check_statement(statement& s, bool skip_block_scope = false) {
        switch (s.type) {
            case Root: {
                auto& root = as_node<root_statement>(s);
                for (auto& st : root.statements) {
                    check_statement(*st);
                }
            }
            break;
            case Block: {
                auto& blk = as_node<block>(s);
                if (!skip_block_scope) {
                    scopes.emplace_back(type_checker_scope{});
                }

                for (auto& st : blk.statements) {
                    check_statement(*st);
                }

                if (!skip_block_scope) {
                    scopes.pop_back();
                }
            }
            break;
            case IfStatement: {
                auto& if_stmt = as_node<if_statement>(s);
                auto& cond = if_stmt.condition;
                if (!cond->evaluated_type.has_value()) {
                    type_checker_utils::evaluate_expression_type(*cond, *this);
                }

                // The condition must evaluate to exactly "boolean".
                const std::string cond_type = cond->evaluated_type.value_or("void");
                if (cond_type != "boolean") {
                    throw type_error(cond->t, cond_type);
                }

                check_statement(*if_stmt.block);
            }
            break;
            case VariableDeclaration: {
                auto& decl = as_node<variable_declaration>(s);
                // Only the innermost scope is consulted, so shadowing an outer
                // variable is allowed.
                if (has_variable(decl.name)) {
                    throw std::runtime_error("Duplicate variable name detected");
                }

                std::optional<std::string> init_type = std::nullopt;
                if (decl.initial_expression != nullptr) {
                    if (!decl.initial_expression->evaluated_type.has_value()) {
                        type_checker_utils::evaluate_expression_type(*decl.initial_expression, *this);
                    }
                    init_type = decl.initial_expression->evaluated_type;
                }

                // An explicit annotation wins; otherwise the type is inferred
                // from the initializer.
                auto type = decl.type;
                if (init_type.has_value()) {
                    if (!type.has_value()) {
                        type = init_type;
                    } else if (!type_checker_utils::is_convertible_to(init_type.value(), type.value())) {
                        throw type_error(decl.initial_expression->t, init_type.value());
                    }
                }

                if (!type.has_value()) {
                    throw std::runtime_error("Cannot infer variable declaration type");
                }

                scopes.back().variables.insert_or_assign(decl.name, type.value());
            }
            break;
            case FunctionDeclaration: {
                auto& decl = as_node<function_declaration>(s);
                if (has_function(decl.name)) {
                    throw std::runtime_error("Duplicate function name detected");
                }

                // A missing return type annotation means "void"; the
                // normalized value is written back into the AST node.
                if (!decl.return_type.has_value()) {
                    decl.return_type = "void";
                }

                runtime_function f;
                f.return_type = decl.return_type.value();
                for (auto& a : decl.parameters) {
                    f.arguments.emplace_back(a.type);
                }

                // Registered before the body is checked, so the function is
                // already resolvable from inside its own body.
                scopes.back().functions.insert_or_assign(decl.name, f);

                // NOTE(review): parameters are not added to the function
                // scope's `variables`, so a body that reads a parameter will
                // hit "Missing identifier". The parameter name field is not
                // visible from here — confirm and register them.
                auto& function_scope = scopes.emplace_back(type_checker_scope{});
                function_scope.intended_return_type = decl.return_type;

                check_statement(*decl.body, true);
                scopes.pop_back();
            }
            break;
            case Expression: {
                auto& expr = as_node<expression>(s);
                type_checker_utils::evaluate_expression_type(expr, *this);
            }
            break;
            case ReturnStatement: {
                auto& rtr = as_node<return_statement>(s);
                auto* value = rtr.expression.get();

                if (value != nullptr && !value->evaluated_type.has_value()) {
                    type_checker_utils::evaluate_expression_type(*value, *this);
                }

                // The declared return type lives on the function's scope, which
                // is not the innermost one when the `return` sits in a nested
                // block (e.g. inside an `if`).
                const std::string expected = enclosing_return_type();
                const std::string actual = value != nullptr ? value->evaluated_type.value_or("void") : std::string("void");

                if (!type_checker_utils::is_convertible_to(actual, expected)) {
                    // A bare `return` has no expression token to point at, so
                    // it is reported as a structural error instead of
                    // dereferencing the null expression.
                    if (value == nullptr) {
                        throw std::runtime_error("Return statement is missing a value");
                    }
                    throw type_error(value->t, actual);
                }
            }
            break;
            default:
                // Other statement kinds carry nothing to check here.
                break;
        }
    }

    /// True if `name` is declared as a variable in the innermost scope.
    [[nodiscard]] bool has_variable(const std::string_view& name) const {
        const auto& innermost = scopes.back();
        return innermost.variables.find(std::string(name)) != innermost.variables.end();
    }

    /// True if `name` is declared as a function in any scope on the stack.
    [[nodiscard]] bool has_function(const std::string_view& name) const {
        const std::string key(name);
        for (const auto& scope : std::ranges::reverse_view(scopes)) {
            if (scope.functions.find(key) != scope.functions.end()) {
                return true;
            }
        }

        return false;
    }
public:
    /// Prepares a checker over `[start, end)` and opens the outermost scope.
    type_checker(T start, E end) : current(start), end(end) {
        scopes.emplace_back();
    }

    /// Checks every remaining statement in the range, in order.
    /// Stops at the first error by letting the exception propagate.
    void check() {
        while (current != end) {
            check_statement(**current);
            ++current;
        }
    }

    /// Resolves a variable's type, searching from the innermost scope outwards.
    /// The returned view refers to the string stored in the scope and is only
    /// valid while that scope stays on the stack and the entry is not replaced.
    /// @throws std::runtime_error if no scope declares `name`.
    [[nodiscard]] std::string_view get_identifier_type(const std::string_view &name) const override {
        const std::string key(name);
        for (const auto& scope : std::ranges::reverse_view(scopes)) {
            if (auto itx = scope.variables.find(key); itx != scope.variables.end()) {
                return itx->second;
            }
        }

        throw std::runtime_error("Missing identifier");
    }

    /// Resolves a function's signature, searching from the innermost scope
    /// outwards. The reference is only valid while the declaring scope stays
    /// on the stack.
    /// @throws std::runtime_error if no scope declares `name`.
    [[nodiscard]] const runtime_function& get_function(const std::string_view &name) const override {
        const std::string key(name);
        for (const auto& scope : std::ranges::reverse_view(scopes)) {
            if (auto itx = scope.functions.find(key); itx != scope.functions.end()) {
                return itx->second;
            }
        }

        throw std::runtime_error("Missing function");
    }
};
}