[general] Parse, type-check and interpret 'assignment' statement

This commit is contained in:
jstefanelli 2025-06-10 17:20:28 +02:00
parent 7c1612d9c2
commit f90a2ed8f1
Signed by: jstefanelli
GPG key ID: 60EDE2437640D2AA
3 changed files with 249 additions and 23 deletions

View file

@ -19,6 +19,7 @@ namespace jlx {
VariableDeclaration, VariableDeclaration,
IfStatement, IfStatement,
ReturnStatement, ReturnStatement,
Assignment
}; };
export enum expression_type { export enum expression_type {
@ -96,6 +97,19 @@ namespace jlx {
~return_statement() override = default; ~return_statement() override = default;
}; };
export struct assignment : public statement {
explicit assignment(const token& t): statement(Assignment, t) {
}
assignment(std::unique_ptr<expression> target, std::unique_ptr<expression> value, const token& t) : statement(Assignment, t), target(std::move(target)), value(std::move(value)) {
}
std::unique_ptr<expression> target;
std::unique_ptr<expression> value;
};
export struct block : public statement { export struct block : public statement {
explicit block(const token& t) : statement(Block, t) { explicit block(const token& t) : statement(Block, t) {
@ -611,6 +625,35 @@ namespace jlx {
return current->content; return current->content;
} }
std::unique_ptr<statement> parse_assignment() {
auto& start = *current;
auto ptr = current;
if (ptr->type != Identifier) {
fail_invalid_token(*current);
}
auto target = std::make_unique<identifier_expression>(start, start.content);
++ptr;
if (ptr == last) {
return nullptr;
}
if (ptr->type != Operator && ptr->content != "=") {
return nullptr;
}
next();
next();
auto expr = parse_expression();
return std::make_unique<assignment>(std::move(target), std::move(expr), start);
}
std::unique_ptr<statement> parse_statement(bool top_level = false) { std::unique_ptr<statement> parse_statement(bool top_level = false) {
if (current == last) { if (current == last) {
return nullptr; return nullptr;
@ -628,6 +671,13 @@ namespace jlx {
} }
} }
if (current->type == token_type::Identifier) {
auto res = parse_assignment();
if (res != nullptr) {
return res;
}
}
return parse_expression(); return parse_expression();
} }

View file

@ -9,6 +9,8 @@ module;
#include <ranges> #include <ranges>
#include <format> #include <format>
#include <iostream> #include <iostream>
#include <functional>
#include <memory>
export module jlx:interpreter; export module jlx:interpreter;
import :ast; import :ast;
@ -17,6 +19,12 @@ import :type_checker;
namespace jlx { namespace jlx {
using runtime_value = std::variant<std::string, int8_t, int16_t, int32_t, int64_t, uint8_t, uint16_t, uint32_t, uint64_t, float, double, bool>; using runtime_value = std::variant<std::string, int8_t, int16_t, int32_t, int64_t, uint8_t, uint16_t, uint32_t, uint64_t, float, double, bool>;
enum class set_result {
success,
constant,
not_found
};
struct interpreter_variable { struct interpreter_variable {
std::string type; std::string type;
bool writable; bool writable;
@ -52,6 +60,9 @@ namespace jlx {
virtual interpreter_variable_scope& current_variable_scope() = 0; virtual interpreter_variable_scope& current_variable_scope() = 0;
virtual bool is_in_function() = 0; virtual bool is_in_function() = 0;
virtual void inject_value(const std::string_view&, runtime_value) = 0; virtual void inject_value(const std::string_view&, runtime_value) = 0;
virtual set_result set_value(const std::string_view&, runtime_value) = 0;
virtual runtime_value type_cast(const std::string_view&, const runtime_value&) = 0;
virtual ~base_interpreter() = default; virtual ~base_interpreter() = default;
}; };
@ -383,7 +394,7 @@ namespace jlx {
auto* lv = dynamic_cast<const literal_value*>(&e); auto* lv = dynamic_cast<const literal_value*>(&e);
if (lv == nullptr) { if (lv == nullptr) {
throw interpreter_error("Internal interpreter error", e.t); throw interpreter_error("Internal interpreter error #0", e.t);
} }
auto& content = lv->t.content; auto& content = lv->t.content;
@ -576,7 +587,7 @@ namespace jlx {
auto* svo = dynamic_cast<const single_operation*>(&e); auto* svo = dynamic_cast<const single_operation*>(&e);
if (svo == nullptr) { if (svo == nullptr) {
throw interpreter_error("Internal interpreter error", e.t); throw interpreter_error("Internal interpreter error #1", e.t);
} }
auto val = eval_expression(*svo->operand); auto val = eval_expression(*svo->operand);
@ -604,7 +615,7 @@ namespace jlx {
auto* dvo = dynamic_cast<const dual_operation*>(&e); auto* dvo = dynamic_cast<const dual_operation*>(&e);
if (dvo == nullptr) { if (dvo == nullptr) {
throw interpreter_error("Internal interpreter error", e.t); throw interpreter_error("Internal interpreter error #2", e.t);
} }
if (dvo->operator_token.content == "||") { if (dvo->operator_token.content == "||") {
@ -637,7 +648,7 @@ namespace jlx {
auto* call = dynamic_cast<const function_call*>(&e); auto* call = dynamic_cast<const function_call*>(&e);
if (call == nullptr) { if (call == nullptr) {
throw interpreter_error("Internal interpreter error", e.t); throw interpreter_error("Internal interpreter error #3", e.t);
} }
auto it = functions.find(call->function_name); auto it = functions.find(call->function_name);
@ -652,7 +663,7 @@ namespace jlx {
auto arg = eval_expression(*a); auto arg = eval_expression(*a);
if (!arg.has_value()) { if (!arg.has_value()) {
throw interpreter_error("Internal interpreter error", a->t); throw interpreter_error("Internal interpreter error #4", a->t);
} }
args.emplace_back(arg.value()); args.emplace_back(arg.value());
} }
@ -664,7 +675,7 @@ namespace jlx {
auto* id = dynamic_cast<const identifier_expression*>(&e); auto* id = dynamic_cast<const identifier_expression*>(&e);
if (id == nullptr) { if (id == nullptr) {
throw interpreter_error("Internal interpreter error", e.t); throw interpreter_error("Internal interpreter error #5", e.t);
} }
for (auto& s : std::ranges::reverse_view(scopes)) { for (auto& s : std::ranges::reverse_view(scopes)) {
@ -680,13 +691,111 @@ namespace jlx {
} }
} }
runtime_value type_cast(const std::string_view& target_type, const runtime_value& value) override {
auto src_type = get_type_for_value(value);
if (src_type == target_type) {
return value;
}
if (!type_checker_utils::is_convertible_to(src_type, target_type)) {
throw std::runtime_error(std::format("Cannot convert '{}' to '{}'", src_type, target_type).c_str());
}
if (target_type == "u16") {
if (src_type == "u8") {
return static_cast<uint16_t>(std::get<uint8_t>(value));
}
}
if (target_type == "i16") {
if (src_type == "u8") {
return static_cast<int16_t>(std::get<uint8_t>(value));
}
if (src_type == "i8") {
return static_cast<int16_t>(std::get<int8_t>(value));
}
}
if (target_type == "u32") {
if (src_type == "u8") {
return static_cast<uint32_t>(std::get<uint8_t>(value));
}
if (src_type == "u16") {
return static_cast<uint32_t>(std::get<uint16_t>(value));
}
}
if (target_type == "i32") {
if (src_type == "u8") {
return static_cast<int32_t>(std::get<uint8_t>(value));
}
if (src_type == "u16") {
return static_cast<int32_t>(std::get<uint16_t>(value));
}
if (src_type == "i8") {
return static_cast<int32_t>(std::get<int8_t>(value));
}
if (src_type == "i16") {
return static_cast<int32_t>(std::get<int16_t>(value));
}
}
if (target_type == "u64") {
if (src_type == "u8") {
return static_cast<uint64_t>(std::get<uint8_t>(value));
}
if (src_type == "u16") {
return static_cast<uint64_t>(std::get<uint16_t>(value));
}
if (src_type == "u32") {
return static_cast<uint64_t>(std::get<uint32_t>(value));
}
}
if (target_type == "i64") {
if (src_type == "u8") {
return static_cast<int64_t>(std::get<uint8_t>(value));
}
if (src_type == "u16") {
return static_cast<int64_t>(std::get<uint16_t>(value));
}
if (src_type == "u32") {
return static_cast<int64_t>(std::get<uint32_t>(value));
}
if (src_type == "i8") {
return static_cast<int64_t>(std::get<int8_t>(value));
}
if (src_type == "i16") {
return static_cast<int64_t>(std::get<int16_t>(value));
}
if (src_type == "i32") {
return static_cast<int64_t>(std::get<int32_t>(value));
}
}
if (target_type == "f64") {
if (src_type == "f32") {
return static_cast<double>(std::get<float>(value));
}
}
throw std::runtime_error(std::format("Conversion from '{}' to '{}' not implemented", src_type, target_type).c_str());
}
std::optional<runtime_value> execute_statement(const statement& s) override { std::optional<runtime_value> execute_statement(const statement& s) override {
switch (s.type) { switch (s.type) {
case Expression: { case Expression: {
auto* ex = dynamic_cast<const expression*>(&s); auto* ex = dynamic_cast<const expression*>(&s);
if (ex == nullptr) { if (ex == nullptr) {
throw interpreter_error("Internal interpreter error", s.t); throw interpreter_error("Internal interpreter error #6", s.t);
} }
return eval_expression(*ex); return eval_expression(*ex);
@ -696,7 +805,7 @@ namespace jlx {
auto* b = dynamic_cast<const block*>(&s); auto* b = dynamic_cast<const block*>(&s);
if (b == nullptr) { if (b == nullptr) {
throw interpreter_error("Internal interpreter error", s.t); throw interpreter_error("Internal interpreter error #7", s.t);
} }
push_variable_scope(); push_variable_scope();
@ -721,7 +830,7 @@ namespace jlx {
auto* r = dynamic_cast<const return_statement*>(&s); auto* r = dynamic_cast<const return_statement*>(&s);
if (r == nullptr) { if (r == nullptr) {
throw interpreter_error("Internal interpreter error", s.t); throw interpreter_error("Internal interpreter error #8", s.t);
} }
auto val = r->expression != nullptr ? eval_expression(*r->expression) : std::nullopt; auto val = r->expression != nullptr ? eval_expression(*r->expression) : std::nullopt;
@ -735,7 +844,7 @@ namespace jlx {
auto* r = dynamic_cast<const root_statement*>(&s); auto* r = dynamic_cast<const root_statement*>(&s);
if (r == nullptr) { if (r == nullptr) {
throw interpreter_error("Internal interpreter error", s.t); throw interpreter_error("Internal interpreter error #9", s.t);
} }
for (auto& st : r->statements) { for (auto& st : r->statements) {
@ -747,7 +856,7 @@ namespace jlx {
auto* d = dynamic_cast<const function_declaration*>(&s); auto* d = dynamic_cast<const function_declaration*>(&s);
if (d == nullptr) { if (d == nullptr) {
throw interpreter_error("Internal interpreter error", s.t); throw interpreter_error("Internal interpreter error #10", s.t);
} }
if (functions.find(d->name) != functions.end()) { if (functions.find(d->name) != functions.end()) {
@ -763,7 +872,7 @@ namespace jlx {
auto* d = dynamic_cast<const variable_declaration*>(&s); auto* d = dynamic_cast<const variable_declaration*>(&s);
if (d == nullptr) { if (d == nullptr) {
throw interpreter_error("Internal interpreter error", s.t); throw interpreter_error("Internal interpreter error #11", s.t);
} }
auto& sc = current_variable_scope(); auto& sc = current_variable_scope();
@ -775,7 +884,7 @@ namespace jlx {
auto val = d->initial_expression != nullptr ? eval_expression(*d->initial_expression) : throw std::runtime_error("Non-initialized variables are not supported yet."); //TODO: This auto val = d->initial_expression != nullptr ? eval_expression(*d->initial_expression) : throw std::runtime_error("Non-initialized variables are not supported yet."); //TODO: This
sc.variables.insert_or_assign(d->name, interpreter_variable { sc.variables.insert_or_assign(d->name, interpreter_variable {
d->type.has_value() ? d->type.value() : "void", d->type.has_value() ? d->type.value() : (val.has_value() ? std::string(get_type_for_value(val.value())) : "void"),
!d->constant, !d->constant,
val.has_value() ? val.value() : runtime_value(0), val.has_value() ? val.value() : runtime_value(0),
}); });
@ -787,7 +896,7 @@ namespace jlx {
auto* is = dynamic_cast<const if_statement*>(&s); auto* is = dynamic_cast<const if_statement*>(&s);
if (is == nullptr) { if (is == nullptr) {
throw interpreter_error("Internal interpreter error", s.t); throw interpreter_error("Internal interpreter errori #12", s.t);
} }
auto condition = eval_expression(*is->condition); auto condition = eval_expression(*is->condition);
@ -803,9 +912,43 @@ namespace jlx {
if (std::get<11>(condition.value())) { if (std::get<11>(condition.value())) {
execute_statement(*is->block); execute_statement(*is->block);
} }
return std::nullopt;
} break;
case Assignment: {
auto* as = dynamic_cast<const assignment*>(&s);
if (as == nullptr) {
throw interpreter_error("Internal interpreter error #13", s.t);
}
auto* target = dynamic_cast<const identifier_expression*>(as->target.get());
if (target == nullptr) {
throw interpreter_error("Invalid target expression for assignment", as->target->t);
}
auto& name = target->name;
auto val = eval_expression(*as->value);
if (!val.has_value()) {
throw interpreter_error("Assignment with no value", as->value->t);
}
auto set_res = set_value(name, val.value());
switch (set_res) {
case set_result::constant:
throw interpreter_error(std::format("Cannot assign to constant value '{}'", name), as->t);
case set_result::not_found:
throw interpreter_error(std::format("Symbol '{}' not found", name), as->t);
default:
return std::nullopt;
}
} break; } break;
} }
throw interpreter_error("Internal interpreter error", s.t); throw interpreter_error("Internal interpreter error #14", s.t);
} }
void push_variable_scope() override { void push_variable_scope() override {
@ -840,9 +983,24 @@ namespace jlx {
}); });
} }
set_result set_value(const std::string_view& name, runtime_value val) override {
for(auto& v : std::ranges::reverse_view(scopes)) {
auto it = v.variables.find(std::string(name));
if (it != v.variables.end()) {
if (it->second.writable) {
it->second.value = type_cast(it->second.type, val);
return set_result::success;
}
return set_result::constant;
}
}
return set_result::not_found;
}
interpreter_function_scope& current_function_scope() override { interpreter_function_scope& current_function_scope() override {
if (function_scopes.empty()) { if (function_scopes.empty()) {
throw std::runtime_error("Internal interpreter error"); throw std::runtime_error("Internal interpreter error #15");
} }
return function_scopes.back(); return function_scopes.back();
@ -856,4 +1014,4 @@ namespace jlx {
return !function_scopes.empty(); return !function_scopes.empty();
} }
}; };
} }

View file

@ -6,6 +6,8 @@ module;
#include <string> #include <string>
#include <string_view> #include <string_view>
#include <algorithm> #include <algorithm>
#include <memory>
#include <unordered_map>
export module jlx:type_checker; export module jlx:type_checker;
@ -466,11 +468,11 @@ namespace jlx {
} }
if (type == "comptime_int") { if (type == "comptime_int") {
type = "i64"; type = "i64";
} else if (type == "comptime_float") { } else if (type == "comptime_float") {
type = "f64"; type = "f64";
} }
if (!type.has_value()) { if (!type.has_value()) {
throw std::runtime_error("Cannot infer variable declaration type"); throw std::runtime_error("Cannot infer variable declaration type");
@ -533,6 +535,22 @@ namespace jlx {
} }
} }
break; break;
case Assignment: {
auto* as = dynamic_cast<assignment*>(&s);
if (!as->target->evaluated_type.has_value()) {
type_checker_utils::evaluate_expression_type(*as->target, *this);
}
if (!as->value->evaluated_type.has_value()) {
type_checker_utils::evaluate_expression_type(*as->value, *this);
}
if (!type_checker_utils::is_convertible_to(as->value->evaluated_type.value(), as->target->evaluated_type.value())) {
throw type_error(as->target->t, as->value->evaluated_type.value());
}
}
break;
} }
} }
@ -614,4 +632,4 @@ namespace jlx {
}); });
} }
}; };
} }