jlx/libjlx/modules/ast.cppm

223 lines
4.4 KiB
C++

module;
#include <concepts>
#include <format>
#include <iterator>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
export module jlx:ast;
import :tokenizer;
namespace jlx {
/// Discriminator stored in statement::type telling which concrete AST node a
/// statement object is.
export enum ast_type {
    Root = 0,                // root_statement: all top-level statements
    Expression = 1,          // expression
    Block = 2,               // block: a `{ ... }` statement list
    FunctionDeclaration = 3, // function_declaration: a `fun ...` definition
    SimpleIdentifier = 4,    // not referenced by any node type in this file
    LiteralValue = 5,        // not referenced by any node type in this file
};
/// Satisfied by bidirectional iterators whose value type is exactly
/// jlx::token; this is the iterator kind parser walks over.
///
/// Both conditions are ordinary constraint atoms, so each is actually
/// required to be *true*. Inside a requires-expression, `std::bidirectional_iterator<T>;`
/// would only check that the expression compiles. std::iter_value_t also
/// accepts raw pointers such as `const token*`, which have no nested value_type.
export template<class T>
concept token_iterator = std::bidirectional_iterator<T>
    && std::same_as<std::iter_value_t<T>, jlx::token>;
/// Common base of every AST node. `type` records which concrete node kind an
/// object is.
///
/// Nodes are owned through std::unique_ptr<statement> (root_statement and
/// block store them that way), so the destructor must be virtual. Without it,
/// destroying a derived node through the base pointer is undefined behaviour
/// and leaks the derived members (strings, vectors, child nodes).
///
/// NOTE(review): this type is not exported, so importers of jlx cannot name it
/// even though the exported parser::parse returns std::unique_ptr<statement>.
struct statement {
    ast_type type;

    explicit statement(ast_type type) : type(type) {
    }

    // Declaring the destructor suppresses the implicit move operations, so
    // keep copy and move explicitly defaulted (same behaviour as before).
    statement(const statement&) = default;
    statement(statement&&) = default;
    statement& operator=(const statement&) = default;
    statement& operator=(statement&&) = default;

    virtual ~statement() = default;
};
/// Top-level AST node: owns every statement parsed from the input.
struct root_statement : statement {
    /// Top-level statements, each owned by this node.
    std::vector<std::unique_ptr<statement>> statements;

    root_statement() : statement{Root} {
    }
};
/// AST node for an expression; it carries no data beyond its ast_type so far.
struct expression : statement {
    expression() : statement{Expression} {
    }
};
/// AST node for a `{ ... }` block: owns the statements found between the braces.
struct block : statement {
    /// Statements inside the braces, each owned by this node.
    std::vector<std::unique_ptr<statement>> statements;

    block() : statement{Block} {
    }
};
/// One `name: type` entry of a function's parameter list.
struct function_parameter {
    std::string name, type; // parameter name, then its type name as written
};
/// AST node for a `fun name(p1: T1, ...) [: R] { ... }` definition.
struct function_declaration : statement {
    std::string name;                           // function name
    std::vector<function_parameter> parameters; // in declaration order
    std::optional<std::string> return_type;     // empty when no `: R` was written
    std::unique_ptr<block> body;                // the `{ ... }` body

    function_declaration() : statement{FunctionDeclaration} {
    }
};
/// Recursive-descent parser that turns a range of jlx::token into an AST.
///
/// @tparam T iterator over jlx::token values (see token_iterator).
/// @tparam E sentinel marking the end of the token range.
///
/// Cursor convention: every parse_* helper is entered with `current` on the
/// first token of the construct it parses. It returns with `current` on the
/// first token after that construct, which may be `last`.
///
/// Only `fun` declarations are parsed so far. The keywords `let`, `var` and
/// `if` are recognised but rejected with an "Unsupported statement" error.
/// All errors are reported by throwing std::runtime_error.
export template<token_iterator T, std::sentinel_for<T> E>
class parser {
    T current;
    E last;

    /// Throws std::runtime_error naming the offending token and its position.
    [[noreturn]] void fail_invalid_token(const token& t) {
        throw std::runtime_error(std::format("Invalid token {} at {}:{}", t.content, t.line, t.col));
    }

    /// Throws std::runtime_error for a construct cut off by the end of input.
    [[noreturn]] void fail_invalid_eof() {
        throw std::runtime_error("Unexpected end-of-file");
    }

    /// Throws std::runtime_error for a recognised keyword whose statement form
    /// is not implemented yet.
    [[noreturn]] void fail_unsupported(const token& t) {
        throw std::runtime_error(std::format("Unsupported statement {} at {}:{}", t.content, t.line, t.col));
    }

    /// Moves past the current token. Reaching `last` is allowed; this is used
    /// when a construct may legitimately be the final one in the input.
    void advance() {
        ++current;
    }

    /// Moves past the current token when more input is required.
    /// @throws std::runtime_error if that leaves the cursor at `last`.
    void next() {
        advance();
        if (current == last) {
            fail_invalid_eof();
        }
    }

    /// True when the current token has the given type and text.
    /// Precondition: current != last.
    bool at(token_type type, const char* text) {
        return current->type == type && current->content == text;
    }

    /// Requires the current token to be `type`/`text`; throws otherwise.
    void expect(token_type type, const char* text) {
        if (!at(type, text)) {
            fail_invalid_token(*current);
        }
    }

    /// Requires an identifier, returns its text and moves past it.
    /// Every identifier in the grammar is followed by more tokens, so reaching
    /// the end of input here is an error.
    std::string take_identifier() {
        if (current->type != Identifier) {
            fail_invalid_token(*current);
        }
        std::string text(current->content);
        next();
        return text;
    }

    /// Parses a type name. Types are currently a single identifier.
    std::string parse_type() {
        return take_identifier();
    }

    /// Parses `{ statement* }`. The closing brace may be the last token.
    std::unique_ptr<block> parse_block() {
        expect(Punctuation, "{");
        next(); // at least the closing brace must follow
        auto result = std::make_unique<block>();
        while (!at(Punctuation, "}")) {
            result->statements.push_back(parse_statement());
            if (current == last) {
                fail_invalid_eof(); // block was never closed
            }
        }
        advance(); // consume '}'
        return result;
    }

    /// Parses `fun name(p1: T1, p2: T2, ...) [: R] { ... }`.
    std::unique_ptr<function_declaration> parse_function() {
        expect(Keyword, "fun");
        next();
        auto decl = std::make_unique<function_declaration>();
        decl->name = take_identifier();
        expect(Punctuation, "(");
        next();
        // Stop only on the ')' punctuation token; every other token must form
        // a `name: type` parameter, preceded by ',' after the first one.
        for (bool first = true; !at(Punctuation, ")"); first = false) {
            if (!first) {
                expect(Punctuation, ",");
                next();
            }
            std::string param_name = take_identifier();
            expect(Punctuation, ":");
            next();
            std::string param_type = parse_type();
            decl->parameters.push_back(function_parameter{std::move(param_name), std::move(param_type)});
        }
        next(); // consume ')'; a return type or the body must follow
        if (at(Punctuation, ":")) {
            next();
            decl->return_type = parse_type();
        }
        decl->body = parse_block();
        return decl;
    }

    /// Parses one statement starting at the current token.
    /// Precondition: current != last.
    std::unique_ptr<statement> parse_statement() {
        if (at(Keyword, "fun")) {
            return parse_function();
        }
        if (at(Keyword, "let") || at(Keyword, "var") || at(Keyword, "if")) {
            fail_unsupported(*current);
        }
        fail_invalid_token(*current);
    }

public:
    /// @param current first token to parse.
    /// @param last    end of the token range.
    parser(T current, E last) : current(std::move(current)), last(std::move(last)) {
    }

    /// Parses the whole token range into a root_statement holding every
    /// top-level statement in input order. An empty range yields an empty root.
    /// @throws std::runtime_error on malformed, truncated or unsupported input.
    std::unique_ptr<statement> parse() {
        auto root = std::make_unique<root_statement>();
        while (current != last) {
            root->statements.push_back(parse_statement());
        }
        return root;
    }
};
}