[libjlx] Implement basic type_checking
This commit is contained in:
parent
07d2a25712
commit
e3d912c6a5
4 changed files with 304 additions and 45 deletions
|
|
@ -63,5 +63,15 @@ int main(int argc, char** argv) {
|
|||
|
||||
std::cout << "Parsed " << rt->statements.size() << " statements" << std::endl;
|
||||
|
||||
auto type_checker = jlx::type_checker(rt->statements.begin(), rt->statements.end());
|
||||
try {
|
||||
type_checker.check();
|
||||
} catch (jlx::type_error& err) {
|
||||
std::cerr << "Type error: " << err.what() << std::endl;
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::cout << "Type checking OK." << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,8 +17,8 @@ namespace jlx {
|
|||
Block,
|
||||
FunctionDeclaration,
|
||||
VariableDeclaration,
|
||||
LiteralValue,
|
||||
IfStatement,
|
||||
ReturnStatement,
|
||||
};
|
||||
|
||||
export enum expression_type {
|
||||
|
|
@ -52,7 +52,7 @@ namespace jlx {
|
|||
|
||||
}
|
||||
|
||||
explicit root_statement(std::vector<std::unique_ptr<statement>> statements, const token& t) : statement(Root, t), statements(std::move(statements)) {
|
||||
root_statement(std::vector<std::unique_ptr<statement>> statements, const token& t) : statement(Root, t), statements(std::move(statements)) {
|
||||
|
||||
}
|
||||
|
||||
|
|
@ -77,11 +77,25 @@ namespace jlx {
|
|||
expression_type et = EtInvalid;
|
||||
std::optional<std::string> evaluated_type;
|
||||
|
||||
virtual std::unique_ptr<expression> clone() const = 0;
|
||||
[[nodiscard]] virtual std::unique_ptr<expression> clone() const = 0;
|
||||
|
||||
~expression() override = default;
|
||||
};
|
||||
|
||||
export struct return_statement : public statement {
|
||||
explicit return_statement(const token& t) : statement(ReturnStatement, t) {
|
||||
|
||||
}
|
||||
|
||||
return_statement(std::unique_ptr<expression> expression, const token& t) : statement(ReturnStatement, t), expression(std::move(expression)) {
|
||||
|
||||
}
|
||||
|
||||
std::unique_ptr<expression> expression;
|
||||
|
||||
~return_statement() override = default;
|
||||
};
|
||||
|
||||
export struct block : public statement {
|
||||
explicit block(const token& t) : statement(Block, t) {
|
||||
|
||||
|
|
@ -455,6 +469,13 @@ namespace jlx {
|
|||
}
|
||||
}
|
||||
fail_invalid_token(*current);
|
||||
break;
|
||||
case Keyword:
|
||||
if (previous != nullptr) {
|
||||
return previous;
|
||||
}
|
||||
fail_invalid_token(*current);
|
||||
break;
|
||||
default:
|
||||
fail_invalid_token(*current);
|
||||
}
|
||||
|
|
@ -531,6 +552,23 @@ namespace jlx {
|
|||
return std::make_unique<function_declaration>(start, std::move(function_name), std::move(params), std::move(return_type), std::move(block));
|
||||
}
|
||||
|
||||
std::unique_ptr<return_statement> parse_return_statement() {
|
||||
auto start = *current;
|
||||
|
||||
if (current->type != Keyword || current->content != "return") {
|
||||
fail_invalid_token(*current);
|
||||
}
|
||||
|
||||
next(false);
|
||||
|
||||
std::unique_ptr<expression> expr = nullptr;
|
||||
if (current != last) {
|
||||
expr = parse_expression();
|
||||
}
|
||||
|
||||
return std::make_unique<return_statement>(std::move(expr), start);
|
||||
}
|
||||
|
||||
std::string parse_type(){
|
||||
if (current->type != Identifier) {
|
||||
fail_invalid_token(*current);
|
||||
|
|
@ -551,6 +589,8 @@ namespace jlx {
|
|||
return parse_if_statement();
|
||||
} else if (current->content == "fun" && top_level) {
|
||||
return parse_function();
|
||||
} else if (current->content == "return") {
|
||||
return parse_return_statement();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -80,13 +80,14 @@ namespace jlx {
|
|||
export class tokenizer {
|
||||
source_stream<char> source;
|
||||
|
||||
static constexpr std::array<std::string, 5> keywords = {{
|
||||
static constexpr std::array<std::string, 6> keywords = {{
|
||||
"if",
|
||||
"else",
|
||||
"fun",
|
||||
//"struct",
|
||||
"let",
|
||||
"var"
|
||||
"var",
|
||||
"return"
|
||||
}};
|
||||
|
||||
static constexpr std::array<char, 7> punctuations = {{
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
module;
|
||||
|
||||
#include <ranges>
|
||||
#include <stdexcept>
|
||||
#include <format>
|
||||
#include <string>
|
||||
|
|
@ -19,18 +20,20 @@ namespace jlx {
|
|||
|
||||
export class type_checker_runtime_context {
|
||||
public:
|
||||
virtual std::string get_identifier_type(const std::string_view&) const = 0;
|
||||
virtual runtime_function get_function(const std::string_view&) const = 0;
|
||||
virtual ~type_checker_runtime_context() = default;
|
||||
|
||||
virtual std::string_view get_identifier_type(const std::string_view&) const = 0;
|
||||
virtual const runtime_function& get_function(const std::string_view&) const = 0;
|
||||
};
|
||||
|
||||
/// Error raised when an expression has a type that is not valid where it is used.
///
/// Inherits publicly from std::runtime_error so that `what()` is reachable
/// by callers and handlers for std::exception / std::runtime_error catch it.
export class type_error : public std::runtime_error {
public:
    /// @param t             Token at which the offending type was found; its
    ///                      source file, line and column go into the message.
    /// @param provided_type Name of the type that was rejected.
    explicit type_error(const token& t, const std::string_view& provided_type) : std::runtime_error(std::format("Type {} is invalid at {}:{}:{}", provided_type, t.source_file, t.line, t.col)) {

    }
};
|
||||
|
||||
export class type_checker {
|
||||
class type_checker_utils {
|
||||
static std::size_t size_of(const std::string_view& type) {
|
||||
if (type == "u8" || type == "i8" || type == "char") {
|
||||
return 1;
|
||||
|
|
@ -57,7 +60,7 @@ namespace jlx {
|
|||
return type == "f32" || type == "f64" || type == "comptime_float";
|
||||
}
|
||||
|
||||
|
||||
public:
|
||||
static bool is_convertible_to(const std::string_view& source_type, const std::string_view& target_type) {
|
||||
if (target_type == source_type) {
|
||||
return true;
|
||||
|
|
@ -145,10 +148,10 @@ namespace jlx {
|
|||
return false;
|
||||
}
|
||||
|
||||
static void evaluate_expression_type(std::unique_ptr<expression>& expr, const type_checker_runtime_context& ctx) {
|
||||
const auto& content = expr->t.content;
|
||||
if (expr->et == EtLiteralValue) {
|
||||
auto* literal_expression = dynamic_cast<literal_value*>(expr.get());
|
||||
static void evaluate_expression_type(expression& expr, const type_checker_runtime_context& ctx) {
|
||||
const auto& content = expr.t.content;
|
||||
if (expr.et == EtLiteralValue) {
|
||||
auto* literal_expression = dynamic_cast<literal_value*>(&expr);
|
||||
if (literal_expression == nullptr) {
|
||||
throw std::runtime_error("Internal type checker/AST error");
|
||||
}
|
||||
|
|
@ -157,22 +160,25 @@ namespace jlx {
|
|||
switch(literal_expression->t.type) {
|
||||
case token_type::Number:
|
||||
if (std::find(content.begin(), content.end(), '.') != content.end()) {
|
||||
type = "comptime.float";
|
||||
type = "comptime_float";
|
||||
} else {
|
||||
type = "comptime.int";
|
||||
type = "comptime_int";
|
||||
}
|
||||
break;
|
||||
case token_type::String:
|
||||
type = "comptime.string";
|
||||
type = "string";
|
||||
break;
|
||||
case token_type::Boolean:
|
||||
type = "boolean";
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error("Unsupported token for literal value expression");
|
||||
}
|
||||
|
||||
expr->evaluated_type = type;
|
||||
expr.evaluated_type = type;
|
||||
return;
|
||||
} else if (expr->et == EtSingleValueOperation) {
|
||||
auto* single_op = dynamic_cast<single_operation*>(expr.get());
|
||||
} else if (expr.et == EtSingleValueOperation) {
|
||||
auto* single_op = dynamic_cast<single_operation*>(&expr);
|
||||
|
||||
if (single_op == nullptr) {
|
||||
throw std::runtime_error("Internal type checker/AST error");
|
||||
|
|
@ -184,20 +190,20 @@ namespace jlx {
|
|||
|
||||
auto& opr = single_op->operand;
|
||||
if (opr->evaluated_type == std::nullopt) {
|
||||
evaluate_expression_type(opr, ctx);
|
||||
evaluate_expression_type(*opr, ctx);
|
||||
}
|
||||
|
||||
auto& op = single_op->operator_token.content;
|
||||
auto base_type = opr->evaluated_type.has_value() ? opr->evaluated_type.value() : "void";
|
||||
if (op == "+") {
|
||||
if (type_is_integer(base_type) || type_is_float(base_type)) {
|
||||
expr->evaluated_type = std::move(base_type);
|
||||
expr.evaluated_type = std::move(base_type);
|
||||
return;
|
||||
}
|
||||
throw type_error(opr->t, base_type);
|
||||
} else if(op == "-") {
|
||||
if (type_is_integer(base_type) || type_is_float(base_type)) {
|
||||
expr->evaluated_type = std::move(base_type);
|
||||
expr.evaluated_type = std::move(base_type);
|
||||
//TODO: Maybe run immediately for 'comptime_int' and 'comptime_float'?
|
||||
return;
|
||||
}
|
||||
|
|
@ -207,13 +213,13 @@ namespace jlx {
|
|||
throw type_error(opr->t, base_type);
|
||||
}
|
||||
|
||||
expr->evaluated_type = "boolean";
|
||||
expr.evaluated_type = "boolean";
|
||||
return;
|
||||
} else {
|
||||
throw std::runtime_error("Cannot type-check expression: invalid operator for single-value operation");
|
||||
}
|
||||
} else if (expr->et == EtDualValueOperation) {
|
||||
auto* dual_op = dynamic_cast<dual_operation*>(expr.get());
|
||||
} else if (expr.et == EtDualValueOperation) {
|
||||
auto* dual_op = dynamic_cast<dual_operation*>(&expr);
|
||||
|
||||
if (dual_op == nullptr) {
|
||||
throw std::runtime_error("Internal type checker/AST error");
|
||||
|
|
@ -226,11 +232,11 @@ namespace jlx {
|
|||
auto& expr0 = dual_op->first_operand;
|
||||
auto& expr1 = dual_op->second_operand;
|
||||
if (!expr0->evaluated_type.has_value()) {
|
||||
evaluate_expression_type(expr0, ctx);
|
||||
evaluate_expression_type(*expr0, ctx);
|
||||
}
|
||||
|
||||
if (!expr1->evaluated_type.has_value()) {
|
||||
evaluate_expression_type(expr1, ctx);
|
||||
evaluate_expression_type(*expr1, ctx);
|
||||
}
|
||||
|
||||
auto& op = dual_op->operator_token.content;
|
||||
|
|
@ -243,27 +249,27 @@ namespace jlx {
|
|||
}
|
||||
|
||||
if (type0 == "string") {
|
||||
expr->evaluated_type = "string";
|
||||
expr.evaluated_type = "string";
|
||||
return;
|
||||
}
|
||||
|
||||
if (type0 == type1) {
|
||||
expr->evaluated_type = std::move(type0);
|
||||
expr.evaluated_type = std::move(type0);
|
||||
return;
|
||||
}
|
||||
|
||||
if ((type_is_integer(type0) && type_is_integer(type1)) || (type_is_float(type0) && type_is_float(type1))) {
|
||||
expr->evaluated_type = size_of(type0) > size_of(type1) ? type0 : type1;
|
||||
expr.evaluated_type = size_of(type0) > size_of(type1) ? type0 : type1;
|
||||
return;
|
||||
}
|
||||
|
||||
if (type_is_float(type0) && type1 == "comptime_int") {
|
||||
expr->evaluated_type = type0;
|
||||
expr.evaluated_type = type0;
|
||||
return;
|
||||
}
|
||||
|
||||
if (type_is_float(type1) && type0 == "comptime_int") {
|
||||
expr->evaluated_type = type1;
|
||||
expr.evaluated_type = type1;
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -274,25 +280,25 @@ namespace jlx {
|
|||
}
|
||||
|
||||
if ((type_is_integer(type0) && type_is_integer(type1)) || (type_is_float(type0) && type_is_float(type1))) {
|
||||
expr->evaluated_type = size_of(type0) > size_of(type1) ? type0 : type1;
|
||||
expr.evaluated_type = size_of(type0) > size_of(type1) ? type0 : type1;
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
if (type_is_float(type0) && type1 == "comptime_int") {
|
||||
expr->evaluated_type = type0;
|
||||
expr.evaluated_type = type0;
|
||||
return;
|
||||
}
|
||||
|
||||
if (type_is_float(type1) && type0 == "comptime_int") {
|
||||
expr->evaluated_type = type1;
|
||||
expr.evaluated_type = type1;
|
||||
return;
|
||||
}
|
||||
|
||||
throw std::runtime_error("Unsupported type combination for operator '+'");
|
||||
} else if (op == "==" || op == "!=") {
|
||||
if (is_convertible_to(type1, type0) || is_convertible_to(type0, type1)) {
|
||||
expr->evaluated_type = "boolean";
|
||||
expr.evaluated_type = "boolean";
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -310,11 +316,11 @@ namespace jlx {
|
|||
throw type_error(expr1->t, type1);
|
||||
}
|
||||
|
||||
expr->evaluated_type = "boolean";
|
||||
expr.evaluated_type = "boolean";
|
||||
return;
|
||||
}
|
||||
} else if (expr->et == EtFunctionCall) {
|
||||
auto* call = dynamic_cast<function_call*>(expr.get());
|
||||
} else if (expr.et == EtFunctionCall) {
|
||||
auto* call = dynamic_cast<function_call*>(&expr);
|
||||
|
||||
if (call == nullptr) {
|
||||
throw std::runtime_error("Internal type checker/AST error");
|
||||
|
|
@ -332,7 +338,7 @@ namespace jlx {
|
|||
auto& arg = call->arguments[i];
|
||||
|
||||
if (!arg->evaluated_type.has_value()) {
|
||||
evaluate_expression_type(arg, ctx);
|
||||
evaluate_expression_type(*arg, ctx);
|
||||
}
|
||||
|
||||
auto src = arg->evaluated_type.has_value() ? arg->evaluated_type.value() : "void";
|
||||
|
|
@ -341,20 +347,222 @@ namespace jlx {
|
|||
}
|
||||
}
|
||||
|
||||
expr->evaluated_type = fn.return_type;
|
||||
expr.evaluated_type = fn.return_type;
|
||||
return;
|
||||
} else if (expr->et == EtIdentifier) {
|
||||
auto* ident_expr = dynamic_cast<identifier_expression*>(expr.get());
|
||||
} else if (expr.et == EtIdentifier) {
|
||||
auto* ident_expr = dynamic_cast<identifier_expression*>(&expr);
|
||||
|
||||
if (ident_expr == nullptr) {
|
||||
throw std::runtime_error("Internal type checker/AST error");
|
||||
}
|
||||
|
||||
expr->evaluated_type = ctx.get_identifier_type(ident_expr->name);
|
||||
expr.evaluated_type = ctx.get_identifier_type(ident_expr->name);
|
||||
return;
|
||||
}
|
||||
|
||||
throw std::runtime_error("Cannot type-check expression: Unknown expression");
|
||||
}
|
||||
};
|
||||
export template<class T>
|
||||
concept statement_iterator = requires() {
|
||||
requires std::same_as<typename T::value_type, std::unique_ptr<statement>>;
|
||||
std::bidirectional_iterator<T>;
|
||||
};
|
||||
|
||||
struct type_checker_scope {
|
||||
std::unordered_map<std::string, std::string> variables;
|
||||
std::unordered_map<std::string, runtime_function> functions;
|
||||
std::optional<std::string> intended_return_type;
|
||||
};
|
||||
|
||||
export template<statement_iterator T, std::sentinel_for<T> E>
|
||||
class type_checker : public type_checker_runtime_context {
|
||||
private:
|
||||
T current;
|
||||
E end;
|
||||
std::vector<type_checker_scope> scopes;
|
||||
|
||||
void check_statement(statement& s, bool skip_block_scope = false) {
|
||||
switch (s.type) {
|
||||
case Root: {
|
||||
auto* root = dynamic_cast<root_statement*>(&s);
|
||||
|
||||
for (auto& st : root->statements) {
|
||||
check_statement(*st);
|
||||
}
|
||||
}
|
||||
break;
|
||||
case Block: {
|
||||
auto* blk = dynamic_cast<block*>(&s);
|
||||
|
||||
if (!skip_block_scope) {
|
||||
scopes.emplace_back(type_checker_scope{});
|
||||
}
|
||||
|
||||
for (auto& st : blk->statements) {
|
||||
check_statement(*st);
|
||||
}
|
||||
|
||||
if (!skip_block_scope) {
|
||||
scopes.pop_back();
|
||||
}
|
||||
}
|
||||
break;
|
||||
case IfStatement: {
|
||||
auto* if_stmt = dynamic_cast<if_statement*>(&s);
|
||||
|
||||
auto& cond = if_stmt->condition;
|
||||
if (!cond->evaluated_type.has_value()) {
|
||||
type_checker_utils::evaluate_expression_type(*if_stmt->condition, *this);
|
||||
}
|
||||
|
||||
auto tp = cond->evaluated_type.has_value() ? cond->evaluated_type.value() : "void";
|
||||
|
||||
if (tp != "boolean") {
|
||||
throw type_error(cond->t, tp);
|
||||
}
|
||||
|
||||
check_statement(*if_stmt->block);
|
||||
}
|
||||
break;
|
||||
case VariableDeclaration: {
|
||||
auto* decl = dynamic_cast<variable_declaration*>(&s);
|
||||
|
||||
if (has_variable(decl->name)) {
|
||||
throw std::runtime_error("Duplicate variable name detected");
|
||||
}
|
||||
|
||||
std::optional<std::string> init_type = std::nullopt;
|
||||
if (decl->initial_expression != nullptr) {
|
||||
if (!decl->initial_expression->evaluated_type.has_value()) {
|
||||
type_checker_utils::evaluate_expression_type(*decl->initial_expression, *this);
|
||||
}
|
||||
init_type = decl->initial_expression->evaluated_type;
|
||||
}
|
||||
|
||||
auto type = decl->type;
|
||||
|
||||
if (init_type.has_value()) {
|
||||
if (!type.has_value()) {
|
||||
type = init_type;
|
||||
} else if (!type_checker_utils::is_convertible_to(init_type.value(), type.value())) {
|
||||
throw type_error(decl->initial_expression->t, init_type.value());
|
||||
}
|
||||
}
|
||||
|
||||
if (!type.has_value()) {
|
||||
throw std::runtime_error("Cannot infer variable declaration type");
|
||||
}
|
||||
|
||||
scopes[scopes.size() - 1].variables.insert_or_assign(decl->name, type.value());
|
||||
}
|
||||
break;
|
||||
case FunctionDeclaration: {
|
||||
auto* decl = dynamic_cast<function_declaration*>(&s);
|
||||
|
||||
if (has_function(decl->name)) {
|
||||
throw std::runtime_error("Duplicate function name detected");
|
||||
}
|
||||
|
||||
if (!decl->return_type.has_value()) {
|
||||
decl->return_type = "void";
|
||||
}
|
||||
|
||||
runtime_function f;
|
||||
f.return_type = decl->return_type.value();
|
||||
|
||||
for (auto& a : decl->parameters) {
|
||||
f.arguments.emplace_back(a.type);
|
||||
}
|
||||
|
||||
scopes[scopes.size() - 1].functions.insert_or_assign(decl->name, f);
|
||||
|
||||
auto& function_scope = scopes.emplace_back(type_checker_scope{});
|
||||
function_scope.intended_return_type = decl->return_type;
|
||||
|
||||
check_statement(*decl->body, true);
|
||||
|
||||
scopes.pop_back();
|
||||
}
|
||||
break;
|
||||
case Expression: {
|
||||
auto* expr = dynamic_cast<expression*>(&s);
|
||||
|
||||
type_checker_utils::evaluate_expression_type(*expr, *this);
|
||||
}
|
||||
break;
|
||||
case ReturnStatement: {
|
||||
auto* rtr = dynamic_cast<return_statement*>(&s);
|
||||
|
||||
auto& scope = scopes.back();
|
||||
|
||||
if (rtr->expression != nullptr && !rtr->expression->evaluated_type.has_value()) {
|
||||
type_checker_utils::evaluate_expression_type(*rtr->expression, *this);
|
||||
}
|
||||
|
||||
auto rvl = scope.intended_return_type.has_value() ? scope.intended_return_type.value() : "void";
|
||||
auto exvl = rtr->expression != nullptr && rtr->expression->evaluated_type.has_value() ? rtr->expression->evaluated_type.value() : "void";
|
||||
|
||||
if (!type_checker_utils::is_convertible_to(exvl, rvl)) {
|
||||
throw type_error(rtr->expression->t, exvl);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
[[nodiscard]] bool has_variable(const std::string_view& name) const {
|
||||
auto& scope = scopes[scopes.size() - 1];
|
||||
auto itx = scope.variables.find(std::string(name));
|
||||
if (itx != scope.variables.end()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool has_function(const std::string_view& name) const {
|
||||
for (const auto& scope : std::ranges::reverse_view(scopes)) {
|
||||
auto itx = scope.functions.find(std::string(name));
|
||||
if (itx != scope.functions.end()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
public:
|
||||
type_checker(T start, E end) : current(start), end(end) {
|
||||
scopes.emplace_back();
|
||||
}
|
||||
|
||||
void check() {
|
||||
while (current != end) {
|
||||
check_statement(**current);
|
||||
|
||||
++current;
|
||||
}
|
||||
}
|
||||
|
||||
[[nodiscard]] std::string_view get_identifier_type(const std::string_view &name) const override {
|
||||
for (const auto & scope : std::ranges::reverse_view(scopes)) {
|
||||
auto itx = scope.variables.find(std::string(name));
|
||||
if (itx != scope.variables.end()) {
|
||||
return itx->second;
|
||||
}
|
||||
}
|
||||
|
||||
throw std::runtime_error("Missing identifier");
|
||||
}
|
||||
|
||||
[[nodiscard]] const runtime_function& get_function(const std::string_view &name) const override {
|
||||
for (const auto& scope : std::ranges::reverse_view(scopes)) {
|
||||
if (auto itx = scope.functions.find(std::string(name)); itx != scope.functions.end()) {
|
||||
return itx->second;
|
||||
}
|
||||
}
|
||||
|
||||
throw std::runtime_error("Missing function");
|
||||
}
|
||||
};
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue