[general] Parse, type-check and interpret 'assignment' statement
This commit is contained in:
parent
7c1612d9c2
commit
f90a2ed8f1
3 changed files with 249 additions and 23 deletions
|
|
@ -19,6 +19,7 @@ namespace jlx {
|
|||
VariableDeclaration,
|
||||
IfStatement,
|
||||
ReturnStatement,
|
||||
Assignment
|
||||
};
|
||||
|
||||
export enum expression_type {
|
||||
|
|
@ -96,6 +97,19 @@ namespace jlx {
|
|||
~return_statement() override = default;
|
||||
};
|
||||
|
||||
export struct assignment : public statement {
|
||||
explicit assignment(const token& t): statement(Assignment, t) {
|
||||
|
||||
}
|
||||
|
||||
assignment(std::unique_ptr<expression> target, std::unique_ptr<expression> value, const token& t) : statement(Assignment, t), target(std::move(target)), value(std::move(value)) {
|
||||
|
||||
}
|
||||
|
||||
std::unique_ptr<expression> target;
|
||||
std::unique_ptr<expression> value;
|
||||
};
|
||||
|
||||
export struct block : public statement {
|
||||
explicit block(const token& t) : statement(Block, t) {
|
||||
|
||||
|
|
@ -611,6 +625,35 @@ namespace jlx {
|
|||
return current->content;
|
||||
}
|
||||
|
||||
std::unique_ptr<statement> parse_assignment() {
|
||||
auto& start = *current;
|
||||
|
||||
auto ptr = current;
|
||||
|
||||
if (ptr->type != Identifier) {
|
||||
fail_invalid_token(*current);
|
||||
}
|
||||
|
||||
auto target = std::make_unique<identifier_expression>(start, start.content);
|
||||
|
||||
++ptr;
|
||||
|
||||
if (ptr == last) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (ptr->type != Operator && ptr->content != "=") {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
next();
|
||||
next();
|
||||
|
||||
auto expr = parse_expression();
|
||||
|
||||
return std::make_unique<assignment>(std::move(target), std::move(expr), start);
|
||||
}
|
||||
|
||||
std::unique_ptr<statement> parse_statement(bool top_level = false) {
|
||||
if (current == last) {
|
||||
return nullptr;
|
||||
|
|
@ -628,6 +671,13 @@ namespace jlx {
|
|||
}
|
||||
}
|
||||
|
||||
if (current->type == token_type::Identifier) {
|
||||
auto res = parse_assignment();
|
||||
if (res != nullptr) {
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
return parse_expression();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ module;
|
|||
#include <ranges>
|
||||
#include <format>
|
||||
#include <iostream>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
|
||||
export module jlx:interpreter;
|
||||
import :ast;
|
||||
|
|
@ -17,6 +19,12 @@ import :type_checker;
|
|||
namespace jlx {
|
||||
using runtime_value = std::variant<std::string, int8_t, int16_t, int32_t, int64_t, uint8_t, uint16_t, uint32_t, uint64_t, float, double, bool>;
|
||||
|
||||
enum class set_result {
|
||||
success,
|
||||
constant,
|
||||
not_found
|
||||
};
|
||||
|
||||
struct interpreter_variable {
|
||||
std::string type;
|
||||
bool writable;
|
||||
|
|
@ -52,6 +60,9 @@ namespace jlx {
|
|||
virtual interpreter_variable_scope& current_variable_scope() = 0;
|
||||
virtual bool is_in_function() = 0;
|
||||
virtual void inject_value(const std::string_view&, runtime_value) = 0;
|
||||
virtual set_result set_value(const std::string_view&, runtime_value) = 0;
|
||||
|
||||
virtual runtime_value type_cast(const std::string_view&, const runtime_value&) = 0;
|
||||
|
||||
virtual ~base_interpreter() = default;
|
||||
};
|
||||
|
|
@ -383,7 +394,7 @@ namespace jlx {
|
|||
auto* lv = dynamic_cast<const literal_value*>(&e);
|
||||
|
||||
if (lv == nullptr) {
|
||||
throw interpreter_error("Internal interpreter error", e.t);
|
||||
throw interpreter_error("Internal interpreter error #0", e.t);
|
||||
}
|
||||
|
||||
auto& content = lv->t.content;
|
||||
|
|
@ -576,7 +587,7 @@ namespace jlx {
|
|||
auto* svo = dynamic_cast<const single_operation*>(&e);
|
||||
|
||||
if (svo == nullptr) {
|
||||
throw interpreter_error("Internal interpreter error", e.t);
|
||||
throw interpreter_error("Internal interpreter error #1", e.t);
|
||||
}
|
||||
|
||||
auto val = eval_expression(*svo->operand);
|
||||
|
|
@ -604,7 +615,7 @@ namespace jlx {
|
|||
auto* dvo = dynamic_cast<const dual_operation*>(&e);
|
||||
|
||||
if (dvo == nullptr) {
|
||||
throw interpreter_error("Internal interpreter error", e.t);
|
||||
throw interpreter_error("Internal interpreter error #2", e.t);
|
||||
}
|
||||
|
||||
if (dvo->operator_token.content == "||") {
|
||||
|
|
@ -637,7 +648,7 @@ namespace jlx {
|
|||
auto* call = dynamic_cast<const function_call*>(&e);
|
||||
|
||||
if (call == nullptr) {
|
||||
throw interpreter_error("Internal interpreter error", e.t);
|
||||
throw interpreter_error("Internal interpreter error #3", e.t);
|
||||
}
|
||||
|
||||
auto it = functions.find(call->function_name);
|
||||
|
|
@ -652,7 +663,7 @@ namespace jlx {
|
|||
auto arg = eval_expression(*a);
|
||||
|
||||
if (!arg.has_value()) {
|
||||
throw interpreter_error("Internal interpreter error", a->t);
|
||||
throw interpreter_error("Internal interpreter error #4", a->t);
|
||||
}
|
||||
args.emplace_back(arg.value());
|
||||
}
|
||||
|
|
@ -664,7 +675,7 @@ namespace jlx {
|
|||
auto* id = dynamic_cast<const identifier_expression*>(&e);
|
||||
|
||||
if (id == nullptr) {
|
||||
throw interpreter_error("Internal interpreter error", e.t);
|
||||
throw interpreter_error("Internal interpreter error #5", e.t);
|
||||
}
|
||||
|
||||
for (auto& s : std::ranges::reverse_view(scopes)) {
|
||||
|
|
@ -680,13 +691,111 @@ namespace jlx {
|
|||
}
|
||||
}
|
||||
|
||||
runtime_value type_cast(const std::string_view& target_type, const runtime_value& value) override {
|
||||
auto src_type = get_type_for_value(value);
|
||||
|
||||
if (src_type == target_type) {
|
||||
return value;
|
||||
}
|
||||
|
||||
if (!type_checker_utils::is_convertible_to(src_type, target_type)) {
|
||||
throw std::runtime_error(std::format("Cannot convert '{}' to '{}'", src_type, target_type).c_str());
|
||||
}
|
||||
|
||||
if (target_type == "u16") {
|
||||
if (src_type == "u8") {
|
||||
return static_cast<uint16_t>(std::get<uint8_t>(value));
|
||||
}
|
||||
}
|
||||
if (target_type == "i16") {
|
||||
if (src_type == "u8") {
|
||||
return static_cast<int16_t>(std::get<uint8_t>(value));
|
||||
}
|
||||
|
||||
if (src_type == "i8") {
|
||||
return static_cast<int16_t>(std::get<int8_t>(value));
|
||||
}
|
||||
}
|
||||
if (target_type == "u32") {
|
||||
if (src_type == "u8") {
|
||||
return static_cast<uint32_t>(std::get<uint8_t>(value));
|
||||
}
|
||||
|
||||
if (src_type == "u16") {
|
||||
return static_cast<uint32_t>(std::get<uint16_t>(value));
|
||||
}
|
||||
}
|
||||
if (target_type == "i32") {
|
||||
if (src_type == "u8") {
|
||||
return static_cast<int32_t>(std::get<uint8_t>(value));
|
||||
}
|
||||
|
||||
if (src_type == "u16") {
|
||||
return static_cast<int32_t>(std::get<uint16_t>(value));
|
||||
}
|
||||
|
||||
if (src_type == "i8") {
|
||||
return static_cast<int32_t>(std::get<int8_t>(value));
|
||||
}
|
||||
|
||||
if (src_type == "i16") {
|
||||
return static_cast<int32_t>(std::get<int16_t>(value));
|
||||
}
|
||||
}
|
||||
if (target_type == "u64") {
|
||||
if (src_type == "u8") {
|
||||
return static_cast<uint64_t>(std::get<uint8_t>(value));
|
||||
}
|
||||
|
||||
if (src_type == "u16") {
|
||||
return static_cast<uint64_t>(std::get<uint16_t>(value));
|
||||
}
|
||||
|
||||
if (src_type == "u32") {
|
||||
return static_cast<uint64_t>(std::get<uint32_t>(value));
|
||||
}
|
||||
}
|
||||
if (target_type == "i64") {
|
||||
if (src_type == "u8") {
|
||||
return static_cast<int64_t>(std::get<uint8_t>(value));
|
||||
}
|
||||
|
||||
if (src_type == "u16") {
|
||||
return static_cast<int64_t>(std::get<uint16_t>(value));
|
||||
}
|
||||
|
||||
if (src_type == "u32") {
|
||||
return static_cast<int64_t>(std::get<uint32_t>(value));
|
||||
}
|
||||
|
||||
if (src_type == "i8") {
|
||||
return static_cast<int64_t>(std::get<int8_t>(value));
|
||||
}
|
||||
|
||||
if (src_type == "i16") {
|
||||
return static_cast<int64_t>(std::get<int16_t>(value));
|
||||
}
|
||||
|
||||
if (src_type == "i32") {
|
||||
return static_cast<int64_t>(std::get<int32_t>(value));
|
||||
}
|
||||
}
|
||||
if (target_type == "f64") {
|
||||
if (src_type == "f32") {
|
||||
return static_cast<double>(std::get<float>(value));
|
||||
}
|
||||
}
|
||||
|
||||
throw std::runtime_error(std::format("Conversion from '{}' to '{}' not implemented", src_type, target_type).c_str());
|
||||
}
|
||||
|
||||
std::optional<runtime_value> execute_statement(const statement& s) override {
|
||||
switch (s.type) {
|
||||
case Expression: {
|
||||
auto* ex = dynamic_cast<const expression*>(&s);
|
||||
|
||||
if (ex == nullptr) {
|
||||
throw interpreter_error("Internal interpreter error", s.t);
|
||||
throw interpreter_error("Internal interpreter error #6", s.t);
|
||||
}
|
||||
|
||||
return eval_expression(*ex);
|
||||
|
|
@ -696,7 +805,7 @@ namespace jlx {
|
|||
auto* b = dynamic_cast<const block*>(&s);
|
||||
|
||||
if (b == nullptr) {
|
||||
throw interpreter_error("Internal interpreter error", s.t);
|
||||
throw interpreter_error("Internal interpreter error #7", s.t);
|
||||
}
|
||||
|
||||
push_variable_scope();
|
||||
|
|
@ -721,7 +830,7 @@ namespace jlx {
|
|||
auto* r = dynamic_cast<const return_statement*>(&s);
|
||||
|
||||
if (r == nullptr) {
|
||||
throw interpreter_error("Internal interpreter error", s.t);
|
||||
throw interpreter_error("Internal interpreter error #8", s.t);
|
||||
}
|
||||
|
||||
auto val = r->expression != nullptr ? eval_expression(*r->expression) : std::nullopt;
|
||||
|
|
@ -735,7 +844,7 @@ namespace jlx {
|
|||
auto* r = dynamic_cast<const root_statement*>(&s);
|
||||
|
||||
if (r == nullptr) {
|
||||
throw interpreter_error("Internal interpreter error", s.t);
|
||||
throw interpreter_error("Internal interpreter error #9", s.t);
|
||||
}
|
||||
|
||||
for (auto& st : r->statements) {
|
||||
|
|
@ -747,7 +856,7 @@ namespace jlx {
|
|||
auto* d = dynamic_cast<const function_declaration*>(&s);
|
||||
|
||||
if (d == nullptr) {
|
||||
throw interpreter_error("Internal interpreter error", s.t);
|
||||
throw interpreter_error("Internal interpreter error #10", s.t);
|
||||
}
|
||||
|
||||
if (functions.find(d->name) != functions.end()) {
|
||||
|
|
@ -763,7 +872,7 @@ namespace jlx {
|
|||
auto* d = dynamic_cast<const variable_declaration*>(&s);
|
||||
|
||||
if (d == nullptr) {
|
||||
throw interpreter_error("Internal interpreter error", s.t);
|
||||
throw interpreter_error("Internal interpreter error #11", s.t);
|
||||
}
|
||||
|
||||
auto& sc = current_variable_scope();
|
||||
|
|
@ -775,7 +884,7 @@ namespace jlx {
|
|||
auto val = d->initial_expression != nullptr ? eval_expression(*d->initial_expression) : throw std::runtime_error("Non-initialized variables are not supported yet."); //TODO: This
|
||||
|
||||
sc.variables.insert_or_assign(d->name, interpreter_variable {
|
||||
d->type.has_value() ? d->type.value() : "void",
|
||||
d->type.has_value() ? d->type.value() : (val.has_value() ? std::string(get_type_for_value(val.value())) : "void"),
|
||||
!d->constant,
|
||||
val.has_value() ? val.value() : runtime_value(0),
|
||||
});
|
||||
|
|
@ -787,7 +896,7 @@ namespace jlx {
|
|||
auto* is = dynamic_cast<const if_statement*>(&s);
|
||||
|
||||
if (is == nullptr) {
|
||||
throw interpreter_error("Internal interpreter error", s.t);
|
||||
throw interpreter_error("Internal interpreter errori #12", s.t);
|
||||
}
|
||||
|
||||
auto condition = eval_expression(*is->condition);
|
||||
|
|
@ -803,9 +912,43 @@ namespace jlx {
|
|||
if (std::get<11>(condition.value())) {
|
||||
execute_statement(*is->block);
|
||||
}
|
||||
|
||||
return std::nullopt;
|
||||
} break;
|
||||
case Assignment: {
|
||||
auto* as = dynamic_cast<const assignment*>(&s);
|
||||
|
||||
if (as == nullptr) {
|
||||
throw interpreter_error("Internal interpreter error #13", s.t);
|
||||
}
|
||||
|
||||
auto* target = dynamic_cast<const identifier_expression*>(as->target.get());
|
||||
|
||||
if (target == nullptr) {
|
||||
throw interpreter_error("Invalid target expression for assignment", as->target->t);
|
||||
}
|
||||
|
||||
auto& name = target->name;
|
||||
|
||||
auto val = eval_expression(*as->value);
|
||||
|
||||
if (!val.has_value()) {
|
||||
throw interpreter_error("Assignment with no value", as->value->t);
|
||||
}
|
||||
|
||||
auto set_res = set_value(name, val.value());
|
||||
|
||||
switch (set_res) {
|
||||
case set_result::constant:
|
||||
throw interpreter_error(std::format("Cannot assign to constant value '{}'", name), as->t);
|
||||
case set_result::not_found:
|
||||
throw interpreter_error(std::format("Symbol '{}' not found", name), as->t);
|
||||
default:
|
||||
return std::nullopt;
|
||||
}
|
||||
} break;
|
||||
}
|
||||
throw interpreter_error("Internal interpreter error", s.t);
|
||||
throw interpreter_error("Internal interpreter error #14", s.t);
|
||||
}
|
||||
|
||||
void push_variable_scope() override {
|
||||
|
|
@ -840,9 +983,24 @@ namespace jlx {
|
|||
});
|
||||
}
|
||||
|
||||
set_result set_value(const std::string_view& name, runtime_value val) override {
|
||||
for(auto& v : std::ranges::reverse_view(scopes)) {
|
||||
auto it = v.variables.find(std::string(name));
|
||||
if (it != v.variables.end()) {
|
||||
if (it->second.writable) {
|
||||
it->second.value = type_cast(it->second.type, val);
|
||||
return set_result::success;
|
||||
}
|
||||
return set_result::constant;
|
||||
}
|
||||
}
|
||||
|
||||
return set_result::not_found;
|
||||
}
|
||||
|
||||
interpreter_function_scope& current_function_scope() override {
|
||||
if (function_scopes.empty()) {
|
||||
throw std::runtime_error("Internal interpreter error");
|
||||
throw std::runtime_error("Internal interpreter error #15");
|
||||
}
|
||||
|
||||
return function_scopes.back();
|
||||
|
|
@ -856,4 +1014,4 @@ namespace jlx {
|
|||
return !function_scopes.empty();
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ module;
|
|||
#include <string>
|
||||
#include <string_view>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
|
||||
export module jlx:type_checker;
|
||||
|
||||
|
|
@ -466,11 +468,11 @@ namespace jlx {
|
|||
}
|
||||
|
||||
|
||||
if (type == "comptime_int") {
|
||||
type = "i64";
|
||||
} else if (type == "comptime_float") {
|
||||
type = "f64";
|
||||
}
|
||||
if (type == "comptime_int") {
|
||||
type = "i64";
|
||||
} else if (type == "comptime_float") {
|
||||
type = "f64";
|
||||
}
|
||||
|
||||
if (!type.has_value()) {
|
||||
throw std::runtime_error("Cannot infer variable declaration type");
|
||||
|
|
@ -533,6 +535,22 @@ namespace jlx {
|
|||
}
|
||||
}
|
||||
break;
|
||||
case Assignment: {
|
||||
auto* as = dynamic_cast<assignment*>(&s);
|
||||
|
||||
if (!as->target->evaluated_type.has_value()) {
|
||||
type_checker_utils::evaluate_expression_type(*as->target, *this);
|
||||
}
|
||||
|
||||
if (!as->value->evaluated_type.has_value()) {
|
||||
type_checker_utils::evaluate_expression_type(*as->value, *this);
|
||||
}
|
||||
|
||||
if (!type_checker_utils::is_convertible_to(as->value->evaluated_type.value(), as->target->evaluated_type.value())) {
|
||||
throw type_error(as->target->t, as->value->evaluated_type.value());
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -614,4 +632,4 @@ namespace jlx {
|
|||
});
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue