diff --git a/libjlx/modules/ast.cppm b/libjlx/modules/ast.cppm index 1db0d81..24a1a07 100644 --- a/libjlx/modules/ast.cppm +++ b/libjlx/modules/ast.cppm @@ -19,6 +19,7 @@ namespace jlx { VariableDeclaration, IfStatement, ReturnStatement, + Assignment }; export enum expression_type { @@ -96,6 +97,19 @@ namespace jlx { ~return_statement() override = default; }; + export struct assignment : public statement { + explicit assignment(const token& t): statement(Assignment, t) { + + } + + assignment(std::unique_ptr target, std::unique_ptr value, const token& t) : statement(Assignment, t), target(std::move(target)), value(std::move(value)) { + + } + + std::unique_ptr target; + std::unique_ptr value; + }; + export struct block : public statement { explicit block(const token& t) : statement(Block, t) { @@ -611,6 +625,35 @@ namespace jlx { return current->content; } + std::unique_ptr parse_assignment() { + auto& start = *current; + + auto ptr = current; + + if (ptr->type != Identifier) { + fail_invalid_token(*current); + } + + auto target = std::make_unique(start, start.content); + + ++ptr; + + if (ptr == last) { + return nullptr; + } + + if (ptr->type != Operator && ptr->content != "=") { + return nullptr; + } + + next(); + next(); + + auto expr = parse_expression(); + + return std::make_unique(std::move(target), std::move(expr), start); + } + std::unique_ptr parse_statement(bool top_level = false) { if (current == last) { return nullptr; @@ -628,6 +671,13 @@ namespace jlx { } } + if (current->type == token_type::Identifier) { + auto res = parse_assignment(); + if (res != nullptr) { + return res; + } + } + return parse_expression(); } diff --git a/libjlx/modules/interpreter.cppm b/libjlx/modules/interpreter.cppm index 00644f8..eabecc2 100644 --- a/libjlx/modules/interpreter.cppm +++ b/libjlx/modules/interpreter.cppm @@ -9,6 +9,8 @@ module; #include #include #include +#include +#include export module jlx:interpreter; import :ast; @@ -17,6 +19,12 @@ import :type_checker; namespace jlx { using runtime_value = std::variant; + enum class set_result { + success, + constant, + not_found + }; + struct interpreter_variable { std::string type; bool writable; @@ -52,6 +60,9 @@ namespace jlx { virtual interpreter_variable_scope& current_variable_scope() = 0; virtual bool is_in_function() = 0; virtual void inject_value(const std::string_view&, runtime_value) = 0; + virtual set_result set_value(const std::string_view&, runtime_value) = 0; + + virtual runtime_value type_cast(const std::string_view&, const runtime_value&) = 0; virtual ~base_interpreter() = default; }; @@ -383,7 +394,7 @@ namespace jlx { auto* lv = dynamic_cast(&e); if (lv == nullptr) { - throw interpreter_error("Internal interpreter error", e.t); + throw interpreter_error("Internal interpreter error #0", e.t); } auto& content = lv->t.content; @@ -576,7 +587,7 @@ namespace jlx { auto* svo = dynamic_cast(&e); if (svo == nullptr) { - throw interpreter_error("Internal interpreter error", e.t); + throw interpreter_error("Internal interpreter error #1", e.t); } auto val = eval_expression(*svo->operand); @@ -604,7 +615,7 @@ namespace jlx { auto* dvo = dynamic_cast(&e); if (dvo == nullptr) { - throw interpreter_error("Internal interpreter error", e.t); + throw interpreter_error("Internal interpreter error #2", e.t); } if (dvo->operator_token.content == "||") { @@ -637,7 +648,7 @@ namespace jlx { auto* call = dynamic_cast(&e); if (call == nullptr) { - throw interpreter_error("Internal interpreter error", e.t); + throw interpreter_error("Internal interpreter error #3", e.t); } auto it = functions.find(call->function_name); @@ -652,7 +663,7 @@ namespace jlx { auto arg = eval_expression(*a); if (!arg.has_value()) { - throw interpreter_error("Internal interpreter error", a->t); + throw interpreter_error("Internal interpreter error #4", a->t); } args.emplace_back(arg.value()); } @@ -664,7 +675,7 @@ namespace jlx { auto* id = dynamic_cast(&e); if (id == nullptr) { - throw interpreter_error("Internal interpreter error", e.t); + throw interpreter_error("Internal interpreter error #5", e.t); } for (auto& s : std::ranges::reverse_view(scopes)) { @@ -680,13 +691,111 @@ namespace jlx { } } + runtime_value type_cast(const std::string_view& target_type, const runtime_value& value) override { + auto src_type = get_type_for_value(value); + + if (src_type == target_type) { + return value; + } + + if (!type_checker_utils::is_convertible_to(src_type, target_type)) { + throw std::runtime_error(std::format("Cannot convert '{}' to '{}'", src_type, target_type).c_str()); + } + + if (target_type == "u16") { + if (src_type == "u8") { + return static_cast(std::get(value)); + } + } + if (target_type == "i16") { + if (src_type == "u8") { + return static_cast(std::get(value)); + } + + if (src_type == "i8") { + return static_cast(std::get(value)); + } + } + if (target_type == "u32") { + if (src_type == "u8") { + return static_cast(std::get(value)); + } + + if (src_type == "u16") { + return static_cast(std::get(value)); + } + } + if (target_type == "i32") { + if (src_type == "u8") { + return static_cast(std::get(value)); + } + + if (src_type == "u16") { + return static_cast(std::get(value)); + } + + if (src_type == "i8") { + return static_cast(std::get(value)); + } + + if (src_type == "i16") { + return static_cast(std::get(value)); + } + } + if (target_type == "u64") { + if (src_type == "u8") { + return static_cast(std::get(value)); + } + + if (src_type == "u16") { + return static_cast(std::get(value)); + } + + if (src_type == "u32") { + return static_cast(std::get(value)); + } + } + if (target_type == "i64") { + if (src_type == "u8") { + return static_cast(std::get(value)); + } + + if (src_type == "u16") { + return static_cast(std::get(value)); + } + + if (src_type == "u32") { + return static_cast(std::get(value)); + } + + if (src_type == "i8") { + return static_cast(std::get(value)); + } + + if (src_type == "i16") { + return static_cast(std::get(value)); + } + + if (src_type == "i32") { + return static_cast(std::get(value)); + } + } + if (target_type == "f64") { + if (src_type == "f32") { + return static_cast(std::get(value)); + } + } + + throw std::runtime_error(std::format("Conversion from '{}' to '{}' not implemented", src_type, target_type).c_str()); + } + std::optional execute_statement(const statement& s) override { switch (s.type) { case Expression: { auto* ex = dynamic_cast(&s); if (ex == nullptr) { - throw interpreter_error("Internal interpreter error", s.t); + throw interpreter_error("Internal interpreter error #6", s.t); } return eval_expression(*ex); @@ -696,7 +805,7 @@ namespace jlx { auto* b = dynamic_cast(&s); if (b == nullptr) { - throw interpreter_error("Internal interpreter error", s.t); + throw interpreter_error("Internal interpreter error #7", s.t); } push_variable_scope(); @@ -721,7 +830,7 @@ namespace jlx { auto* r = dynamic_cast(&s); if (r == nullptr) { - throw interpreter_error("Internal interpreter error", s.t); + throw interpreter_error("Internal interpreter error #8", s.t); } auto val = r->expression != nullptr ? eval_expression(*r->expression) : std::nullopt; @@ -735,7 +844,7 @@ namespace jlx { auto* r = dynamic_cast(&s); if (r == nullptr) { - throw interpreter_error("Internal interpreter error", s.t); + throw interpreter_error("Internal interpreter error #9", s.t); } for (auto& st : r->statements) { @@ -747,7 +856,7 @@ namespace jlx { auto* d = dynamic_cast(&s); if (d == nullptr) { - throw interpreter_error("Internal interpreter error", s.t); + throw interpreter_error("Internal interpreter error #10", s.t); } if (functions.find(d->name) != functions.end()) { @@ -763,7 +872,7 @@ namespace jlx { auto* d = dynamic_cast(&s); if (d == nullptr) { - throw interpreter_error("Internal interpreter error", s.t); + throw interpreter_error("Internal interpreter error #11", s.t); } auto& sc = current_variable_scope(); @@ -775,7 +884,7 @@ namespace jlx { auto val = d->initial_expression != nullptr ? eval_expression(*d->initial_expression) : throw std::runtime_error("Non-initialized variables are not supported yet."); //TODO: This sc.variables.insert_or_assign(d->name, interpreter_variable { - d->type.has_value() ? d->type.value() : "void", + d->type.has_value() ? d->type.value() : (val.has_value() ? std::string(get_type_for_value(val.value())) : "void"), !d->constant, val.has_value() ? val.value() : runtime_value(0), }); @@ -787,7 +896,7 @@ namespace jlx { auto* is = dynamic_cast(&s); if (is == nullptr) { - throw interpreter_error("Internal interpreter error", s.t); + throw interpreter_error("Internal interpreter errori #12", s.t); } auto condition = eval_expression(*is->condition); @@ -803,9 +912,43 @@ namespace jlx { if (std::get<11>(condition.value())) { execute_statement(*is->block); } + + return std::nullopt; + } break; + case Assignment: { + auto* as = dynamic_cast(&s); + + if (as == nullptr) { + throw interpreter_error("Internal interpreter error #13", s.t); + } + + auto* target = dynamic_cast(as->target.get()); + + if (target == nullptr) { + throw interpreter_error("Invalid target expression for assignment", as->target->t); + } + + auto& name = target->name; + + auto val = eval_expression(*as->value); + + if (!val.has_value()) { + throw interpreter_error("Assignment with no value", as->value->t); + } + + auto set_res = set_value(name, val.value()); + + switch (set_res) { + case set_result::constant: + throw interpreter_error(std::format("Cannot assign to constant value '{}'", name), as->t); + case set_result::not_found: + throw interpreter_error(std::format("Symbol '{}' not found", name), as->t); + default: + return std::nullopt; + } } break; } - throw interpreter_error("Internal interpreter error", s.t); + throw interpreter_error("Internal interpreter error #14", s.t); } void push_variable_scope() override { @@ -840,9 +983,24 @@ namespace jlx { }); } + set_result set_value(const std::string_view& name, runtime_value val) override { + for(auto& v : std::ranges::reverse_view(scopes)) { + auto it = v.variables.find(std::string(name)); + if (it != v.variables.end()) { + if (it->second.writable) { + it->second.value = type_cast(it->second.type, val); + return set_result::success; + } + return set_result::constant; + } + } + + return set_result::not_found; + } + interpreter_function_scope& current_function_scope() override { if (function_scopes.empty()) { - throw std::runtime_error("Internal interpreter error"); + throw std::runtime_error("Internal interpreter error #15"); } return function_scopes.back(); @@ -856,4 +1014,4 @@ namespace jlx { return !function_scopes.empty(); } }; -} \ No newline at end of file +} diff --git a/libjlx/modules/type_checker.cppm b/libjlx/modules/type_checker.cppm index 1722614..f2e8828 100644 --- a/libjlx/modules/type_checker.cppm +++ b/libjlx/modules/type_checker.cppm @@ -6,6 +6,8 @@ module; #include #include #include +#include +#include export module jlx:type_checker; @@ -466,11 +468,11 @@ namespace jlx { } - if (type == "comptime_int") { - type = "i64"; - } else if (type == "comptime_float") { - type = "f64"; - } + if (type == "comptime_int") { + type = "i64"; + } else if (type == "comptime_float") { + type = "f64"; + } if (!type.has_value()) { throw std::runtime_error("Cannot infer variable declaration type"); @@ -533,6 +535,22 @@ namespace jlx { } } break; + case Assignment: { + auto* as = dynamic_cast(&s); + + if (!as->target->evaluated_type.has_value()) { + type_checker_utils::evaluate_expression_type(*as->target, *this); + } + + if (!as->value->evaluated_type.has_value()) { + type_checker_utils::evaluate_expression_type(*as->value, *this); + } + + if (!type_checker_utils::is_convertible_to(as->value->evaluated_type.value(), as->target->evaluated_type.value())) { + throw type_error(as->target->t, as->value->evaluated_type.value()); + } + } + break; } } @@ -614,4 +632,4 @@ namespace jlx { }); } }; -} \ No newline at end of file +}