《LLVM IR 学习手记(四):控制流与块语句的实现与基本块解析》

72 阅读 · 约 28 分钟

本文将带大家实现编译器中基础的 if 语句以及块语句的功能。

1. 实现基础的 If 语句功能

1.1 测试文件

expr_01.txt

int a = 3;
int b = 5;
if (a)
  a = 2;
else 
  a = 4;
a * b - 4;

expr_02.txt

int a = 3;
int b = 5;
if (a - 3)
  a = 2;
else 
  a = 4;
a * b - 4;

expr_03.txt

int a = 3;
int b = 5;
if (a)
  a = 2;
a * b - 4;

1.2 词法分析器 (Lexer)

对于词法分析器,首先需要增加对 if 和 else 两种关键字 token 类型的识别。

文法定义

ebnf.txt

prog : stmt*
stmt : decl-stmt | expr-stmt | null-stmt | if-stmt 
null-stmt : ";"
decl-stmt : "int" identifier ("=" expr)? ("," identifier ("=" expr)?)* ";"
if-stmt : "if" "(" expr ")" stmt ( "else" stmt )?
block-stmt : "{" stmt* "}"
expr : assign-expr | add-expr
assign-expr: identifier "=" expr
add-expr : mult-expr (("+" | "-") mult-expr)* 
mult-expr : primary-expr (("*" | "/") primary-expr)* 
primary-expr : identifier | number | "(" expr ")" 
number: ([0-9])+ 
identifier : (a-zA-Z_)(a-zA-Z0-9_)*

实现代码

lexer.h

#pragma once

#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"
#include "type.h"
#include "diag_engine.h"

// char stream -> token

// Kinds of tokens produced by Lexer::NextToken (char stream -> token).
enum class TokenType
{
    number,      // integer literal: [0-9]+
    // identifier (variable name); the enumerator is misspelled but other code refers to it by this name
    indentifier,
    kw_int,      // keyword `int`
    kw_if,       // keyword `if`
    kw_else,     // keyword `else`
    plus,        // +
    minus,       // -
    star,        // *
    slash,       // /
    equal,       // =
    l_parent,    // (
    r_parent,    // )
    semi,        // ;
    comma,       // ,
    eof          // end of file
};

// A single lexical token. All fields have defaults so that a Token that was
// never (or only partially) filled in by the lexer holds well-defined values.
// Previously every member was left uninitialized: e.g. `type` and `value` were
// indeterminate for non-number tokens.
class Token
{
public:
    // Kind of the token.
    TokenType tokenType{TokenType::eof};
    // 1-based source position, filled in by Lexer::NextToken.
    int row{0}, col{0};

    // Numeric value; only meaningful for TokenType::number.
    int value{0};

    // Start of the token text inside the source buffer (used for diagnostics/printing).
    const char *ptr{nullptr};
    // Length of the token text.
    int length{0};

    // Built-in type; the lexer sets it for number tokens.
    CType *type{nullptr};

public:
    // Print the token text and its position to llvm::outs().
    void Dump();
    // Spelling of a token kind as used in diagnostics.
    static llvm::StringRef GetSpellingText(TokenType tokenType);
};

// Hand-written lexer over the main file of an llvm::SourceMgr.
// Each NextToken() call produces one Token; SaveState()/RestoreState() give
// one level of backtracking (used by the parser for lookahead).
class Lexer
{
private:
    DiagEngine &diagEngine;
    llvm::SourceMgr &mgr;

public:
    // Begin lexing at the start of the SourceMgr's main file buffer, on row 1.
    Lexer(DiagEngine &diagEngine, llvm::SourceMgr &mgr) : diagEngine(diagEngine), mgr(mgr)
    {
        unsigned id = mgr.getMainFileID();
        llvm::StringRef buf = mgr.getMemoryBuffer(id)->getBuffer();
        BufPtr = buf.begin();
        LineHeadPtr = buf.begin();
        BufEnd = buf.end();
        row = 1;
    }
    // Lex the next token into `token` and advance the read position past it.
    void NextToken(Token &token);

    // Save / restore the read position. Only one snapshot slot exists: each
    // SaveState() overwrites the previous one.
    // NOTE(review): calling RestoreState() before any SaveState() reads an
    // uninitialized snapshot.
    void SaveState();
    void RestoreState();

    DiagEngine &GetDiagEngine() const
    {
        return diagEngine;
    }

private:
    // Snapshot of the mutable lexing position.
    struct State
    {
        const char *BufPtr;
        const char *LineHeadPtr;
        const char *BufEnd;
        int row;
    };

private:
    // Next character to read.
    const char *BufPtr;
    // First character of the current line; used to compute 1-based columns.
    const char *LineHeadPtr;
    // One past the last character of the buffer.
    const char *BufEnd;
    // Current 1-based line number.
    int row;

    // Single backtracking snapshot (see SaveState/RestoreState).
    State state;
};

lexer.cc

#include "lexer.h"

void Token::Dump()
{
    llvm::StringRef text(ptr, length);
    llvm::outs() << "{" << text << ", row = " << row << ", col = " << col << "}\n";
}

// Return the spelling of a token kind for use in diagnostics: the literal text
// for keywords and punctuation, or a descriptive name ("number", "identifier",
// "eof") for kinds without fixed text.
//
// Every enumerator is handled explicitly and there is no `default`, so
// -Wswitch flags any TokenType added later without a spelling. Previously
// `eof` had no case and fell into the unreachable path.
llvm::StringRef Token::GetSpellingText(TokenType tokenType)
{
    switch (tokenType)
    {
    case TokenType::number:
        return "number";
    case TokenType::indentifier:
        // User-facing text; spelled correctly even though the enumerator is not.
        return "identifier";
    case TokenType::kw_int:
        return "int";
    case TokenType::kw_if:
        return "if";
    case TokenType::kw_else:
        return "else";
    case TokenType::plus:
        return "+";
    case TokenType::minus:
        return "-";
    case TokenType::star:
        return "*";
    case TokenType::slash:
        return "/";
    case TokenType::equal:
        return "=";
    case TokenType::l_parent:
        return "(";
    case TokenType::r_parent:
        return ")";
    case TokenType::semi:
        return ";";
    case TokenType::comma:
        return ",";
    case TokenType::eof:
        return "eof";
    }
    // Only reachable with an out-of-range TokenType value.
    llvm_unreachable("unhandled TokenType in Token::GetSpellingText");
}

// True for the characters that separate tokens: space, tab, CR and LF.
bool IsWhiteSpace(char ch)
{
    switch (ch)
    {
    case ' ':
    case '\t':
    case '\r':
    case '\n':
        return true;
    default:
        return false;
    }
}

// True for the ASCII decimal digits '0'..'9'.
bool IsDigit(char ch)
{
    const bool atLeastZero = '0' <= ch;
    const bool atMostNine = ch <= '9';
    return atLeastZero && atMostNine;
}

// True for characters that may start an identifier: ASCII letters and '_'.
bool IsLetter(char ch)
{
    if (ch == '_')
        return true;
    const bool lower = 'a' <= ch && ch <= 'z';
    const bool upper = 'A' <= ch && ch <= 'Z';
    return lower || upper;
}

// Lex the token starting at BufPtr into `token` and advance BufPtr past it.
// Always sets tokenType, row/col (1-based) and ptr/length (the token text);
// number tokens also get `value` and `type`. At end of input an eof token is
// produced. Unknown characters are reported through diag::err_unknown_char.
void Lexer::NextToken(Token &token)
{
    // Skip whitespace. At each newline, update row and LineHeadPtr so columns
    // are relative to the current line. The bounds check keeps the scan inside
    // the buffer.
    while (BufPtr < BufEnd && IsWhiteSpace(*BufPtr))
    {
        if (*BufPtr == '\n')
        {
            row += 1;
            LineHeadPtr = BufPtr + 1;
        }
        BufPtr++;
    }

    token.row = row;
    token.col = BufPtr - LineHeadPtr + 1;
    // Set the text span before the eof early return too. Otherwise an eof
    // token keeps ptr/length from the previous token, and diagnostics about it
    // (e.g. a missing ';') point at and print the wrong text.
    token.ptr = BufPtr;
    token.length = 0;

    if (BufPtr >= BufEnd)
    {
        token.tokenType = TokenType::eof;
        return;
    }

    if (IsDigit(*BufPtr))
    {
        // Integer literal: accumulate the decimal value digit by digit.
        int val = 0;
        while (BufPtr < BufEnd && IsDigit(*BufPtr))
        {
            val = val * 10 + (*BufPtr++ - '0');
            token.length++;
        }
        token.value = val;
        token.tokenType = TokenType::number;
        token.type = CType::getIntTy();
    }
    else if (IsLetter(*BufPtr))
    {
        // Identifier or keyword: [a-zA-Z_][a-zA-Z0-9_]*
        while (BufPtr < BufEnd && (IsLetter(*BufPtr) || IsDigit(*BufPtr)))
        {
            BufPtr++;
        }
        token.length = BufPtr - token.ptr;
        llvm::StringRef text(token.ptr, token.length);
        if (text == "int")
            token.tokenType = TokenType::kw_int;
        else if (text == "if")
            token.tokenType = TokenType::kw_if;
        else if (text == "else")
            token.tokenType = TokenType::kw_else;
        else
            token.tokenType = TokenType::indentifier;
    }
    else
    {
        // Single-character operators and punctuation.
        switch (*BufPtr)
        {
        case '+': token.tokenType = TokenType::plus; break;
        case '-': token.tokenType = TokenType::minus; break;
        case '*': token.tokenType = TokenType::star; break;
        case '/': token.tokenType = TokenType::slash; break;
        case '=': token.tokenType = TokenType::equal; break;
        case '(': token.tokenType = TokenType::l_parent; break;
        case ')': token.tokenType = TokenType::r_parent; break;
        case ';': token.tokenType = TokenType::semi; break;
        case ',': token.tokenType = TokenType::comma; break;
        default:
            // NOTE(review): BufPtr is not advanced and tokenType is left
            // unchanged here; this presumably relies on Report() aborting on
            // errors — confirm in DiagEngine.
            diagEngine.Report(llvm::SMLoc::getFromPointer(BufPtr), diag::err_unknown_char, *BufPtr);
            return;
        }
        token.length = 1;
        BufPtr++;
    }
}

void Lexer::SaveState()
{
    state.LineHeadPtr = LineHeadPtr;
    state.BufPtr = BufPtr;
    state.BufEnd = BufEnd;
    state.row = row;
}

void Lexer::RestoreState()
{
    LineHeadPtr = state.LineHeadPtr;
    BufPtr = state.BufPtr;
    BufEnd = state.BufEnd;
    row = state.row;
}

1.3 语法分析器 (Parser)

根据文法定义,parser 需要新增对 if 语句的解析,并将单条语句的解析封装为一个独立的函数(ParseStmt)。

实现代码

ast.h

#pragma once

#include <vector>
#include <memory>
#include "llvm/IR/Value.h"
#include "type.h"
#include "lexer.h"

// Binary arithmetic operators carried by a BinaryExpr.
enum class OpCode
{
    add, // +
    sub, // -
    mul, // *
    div  // /
};

// 进行声明
class Program;
class Expr;
class DeclStmt;
class VariableDecl;
class IfStmt;
class VariableAccessExpr;
class BinaryExpr;
class AssignExpr;
class NumberExpr;

// Visitor interface (visitor pattern). Each AST node's Accept() calls the
// matching Visit* method. Code-generating visitors return the llvm::Value
// produced for the node; visitors that only print return nullptr.
class Visitor
{
public:
    virtual ~Visitor() {}
    virtual llvm::Value *VisitProgram(Program *p) = 0;
    virtual llvm::Value *VisitDeclStmt(DeclStmt* decl) = 0;
    virtual llvm::Value *VisitIfStmt(IfStmt* decl) = 0;
    virtual llvm::Value *VisitVariableDecl(VariableDecl* decl) = 0;
    virtual llvm::Value *VisitVariableAccessExpr(VariableAccessExpr* varaccExpr) = 0;
    virtual llvm::Value *VisitAssignExpr(AssignExpr* assignExpr) = 0;
    virtual llvm::Value *VisitBinaryExpr(BinaryExpr *binaryExpr) = 0;
    virtual llvm::Value *VisitNumberExpr(NumberExpr *factorExpr) = 0;
};

// Common base class of all AST nodes.
class ASTNode 
{
public:
    // Kind tag used for LLVM-style RTTI (each subclass's classof() checks it).
    // ND_BlockStmt is declared but no node class uses it yet.
    enum Kind
    {
        ND_BlockStmt,
        ND_IfStmt,
        ND_DeclStmt,
        ND_VariableDecl,
        ND_BinaryExpr,
        ND_NumberExpr,
        ND_VariableAccessExpr,
        ND_AssignExpr
    };

private:
    const Kind kind;
public:
    ASTNode(Kind kind) : kind(kind) {}
    virtual ~ASTNode() {}
    // Double dispatch: subclasses forward to the matching Visitor::Visit*.
    // The base version visits nothing and returns nullptr.
    virtual llvm::Value *Accept(Visitor *v) { return nullptr; }
    Kind getKind() const { return kind; }

public:
    // Semantic type of the node. Sema sets it only for some node kinds, so it
    // starts as nullptr instead of an indeterminate pointer.
    CType *type{nullptr};
    // Source token the node was built from (name, literal, or operator).
    Token token;
};

// 声明语句节点
class DeclStmt : public ASTNode
{
public:
    DeclStmt() : ASTNode(ND_DeclStmt) {}

    llvm::Value *Accept(Visitor *v) override
    {
        return v->VisitDeclStmt(this);
    }

    static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
    {
        return node->getKind() == ND_DeclStmt;
    }
public:
    std::vector<std::shared_ptr<ASTNode>> nodeVec;
};

// 条件判断节点
class IfStmt : public ASTNode
{
public:
    IfStmt() : ASTNode(ND_IfStmt) {}

    llvm::Value *Accept(Visitor *v) override
    {
        return v->VisitIfStmt(this);
    }

    static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
    {
        return node->getKind() == ND_IfStmt;
    }
public:
    std::shared_ptr<ASTNode> condNode;
    std::shared_ptr<ASTNode> thenNode;
    std::shared_ptr<ASTNode> elseNode;
};

// 变量声明节点
class VariableDecl : public ASTNode 
{
public:
    VariableDecl() : ASTNode(ND_VariableDecl) {}

    llvm::Value *Accept(Visitor *v) override
    {
        return v->VisitVariableDecl(this);
    }

    static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
    {
        return node->getKind() == ND_VariableDecl;
    }
};

// 二元表达式节点
class BinaryExpr : public ASTNode
{
public:
    BinaryExpr() : ASTNode(ND_BinaryExpr) {}

    llvm::Value *Accept(Visitor *v) override
    {
        return v->VisitBinaryExpr(this);
    }

    static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
    {
        return node->getKind() == ND_BinaryExpr;
    }
public:
    OpCode op;
    std::shared_ptr<ASTNode> left;
    std::shared_ptr<ASTNode> right;
};

// 赋值表达式节点
class AssignExpr : public ASTNode
{
public:
    AssignExpr() : ASTNode(ND_AssignExpr) {}

    llvm::Value *Accept(Visitor *v) override
    {
        return v->VisitAssignExpr(this);
    }

    static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
    {
        return node->getKind() == ND_AssignExpr;
    }
public:
    std::shared_ptr<ASTNode> left;
    std::shared_ptr<ASTNode> right;
};

// 数字表达式节点
class NumberExpr : public ASTNode
{
public:
    NumberExpr() : ASTNode(ND_NumberExpr) {}

    llvm::Value *Accept(Visitor *v) override
    {
        return v->VisitNumberExpr(this);
    }

    static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
    {
        return node->getKind() == ND_NumberExpr;
    }
};

// 变量访问节点
class VariableAccessExpr : public ASTNode // 变量表达式
{
public:
    VariableAccessExpr() : ASTNode(ND_VariableAccessExpr) {}

    llvm::Value *Accept(Visitor *v) override
    {
        return v->VisitVariableAccessExpr(this);
    }

    static bool classof(const ASTNode *node) // 判断是否能进行指针的强转
    {
        return node->getKind() == ND_VariableAccessExpr;
    }
};

// Root of the AST: the top-level statements of the program, in source order.
class Program
{
public:
    std::vector<std::shared_ptr<ASTNode>> nodeVec;
};

parser.h

#pragma once

#include "ast.h"
#include "lexer.h"
#include "sema.h"

// Recursive-descent parser. It pulls tokens from a Lexer and builds AST nodes
// through Sema. The constructor reads the first token, so `token` always holds
// the current, not-yet-consumed token.
class Parser
{
public:
    Parser(Lexer &lexer, Sema &sema) : lexer(lexer), sema(sema)
    {
        Advance();
    }
    // prog : stmt*  (parses until eof)
    std::shared_ptr<Program> ParseProgram();

private:
    std::shared_ptr<ASTNode> ParseStmt();
    std::shared_ptr<ASTNode> ParseDeclStmt();
    std::shared_ptr<ASTNode> ParseExprStmt();
    std::shared_ptr<ASTNode> ParseIfStmt();
    std::shared_ptr<ASTNode> ParseAssignExpr();
    // expr : assign-expr | add-expr
    std::shared_ptr<ASTNode> ParseExpr();
    // mult-expr
    std::shared_ptr<ASTNode> ParseTerm();
    // primary-expr
    std::shared_ptr<ASTNode> ParseFactor();

    // Token-consumption helpers.
    // Check the current token's kind; reports a diagnostic on mismatch, never consumes.
    bool Expect(TokenType tokenType);
    // Check the current token's kind and consume it on success.
    bool Consume(TokenType tokenType);
    // Consume the current token unconditionally.
    bool Advance();

    DiagEngine &GetDiagEngine() const
    {
        return lexer.GetDiagEngine();
    }

private:
    Lexer &lexer;
    Sema &sema;
    Token token;
};

parser.cc

#include "parser.h"
#include <cassert>

/*
prog : stmt*
stmt : decl-stmt | expr-stmt | null-stmt | if-stmt
null-stmt : ";"
decl-stmt : "int" identifier ("=" expr)? ("," identifier ("=" expr)?)* ";"
if-stmt : "if" "(" expr ")" stmt ( "else" stmt )?
expr-stmt : expr ";"
expr : assign-expr | add-expr
assign-expr: identifier "=" expr
add-expr : mult-expr (("+" | "-") mult-expr)*
mult-expr : primary-expr (("*" | "/") primary-expr)*
primary-expr : identifier | number | "(" expr ")"
number: ([0-9])+
identifier : (a-zA-Z_)(a-zA-Z0-9_)*
*/

// Parse the whole input:
//   prog : stmt*
// Statements that produce no node (null statements `;`) are skipped.
std::shared_ptr<Program> Parser::ParseProgram()
{
    auto program = std::make_shared<Program>();
    while (token.tokenType != TokenType::eof)
    {
        if (auto stmt = ParseStmt())
            program->nodeVec.push_back(stmt);
    }
    return program;
}

// Parse one statement:
//   stmt : decl-stmt | expr-stmt | null-stmt | if-stmt
// The alternative is chosen by the current token. A null statement (`;`) is
// consumed and yields nullptr.
std::shared_ptr<ASTNode> Parser::ParseStmt()
{
    switch (token.tokenType)
    {
    case TokenType::semi:
        Consume(TokenType::semi);
        return nullptr;
    case TokenType::kw_int:
        return ParseDeclStmt();
    case TokenType::kw_if:
        return ParseIfStmt();
    default:
        return ParseExprStmt();
    }
}

// Parse a declaration statement:
//   decl-stmt : "int" identifier ("=" expr)? ("," identifier ("=" expr)?)* ";"
// e.g. `int a, b = 3;` or `int a = 3;`
// Each declarator yields a VariableDecl node. An initializer is lowered to a
// separate AssignExpr (`int a = 3;` -> `int a; a = 3;`). Both kinds of node are
// appended to the returned DeclStmt in source order.
std::shared_ptr<ASTNode> Parser::ParseDeclStmt()
{
    Consume(TokenType::kw_int);
    CType *baseTy = CType::getIntTy();

    auto declStmt = std::make_shared<DeclStmt>();

    bool firstDeclarator = true;
    // Stop at eof as well, so a missing ';' cannot make this loop spin forever.
    while (token.tokenType != TokenType::semi && token.tokenType != TokenType::eof)
    {
        if (!firstDeclarator)
        {
            // Declarators after the first are separated by ','. This must not be
            // written as assert(Consume(...)): with NDEBUG the whole call, and
            // with it the token consumption, would be compiled out.
            if (!Consume(TokenType::comma))
                break;
        }
        firstDeclarator = false;

        // Declare the variable (registers it in the current scope).
        auto variableDecl = sema.SemaVariableDeclNode(token, baseTy);
        declStmt->nodeVec.push_back(variableDecl);

        Token nameToken = token;
        Consume(TokenType::indentifier);

        // Optional initializer: `= expr`.
        if (token.tokenType == TokenType::equal)
        {
            Token opToken = token;
            Advance();

            auto right = ParseExpr();
            auto left = sema.SemaVariableAccessNode(nameToken);
            auto assign = sema.SemaAssignExprNode(left, right, opToken);

            declStmt->nodeVec.push_back(assign);
        }
    }

    // Consume the terminating ';'. Unlike a bare Advance(), this reports
    // err_expected when the ';' is missing.
    Consume(TokenType::semi);

    return declStmt;
}

// Parse an expression statement:
//   expr-stmt : expr ";"
std::shared_ptr<ASTNode> Parser::ParseExprStmt()
{
    std::shared_ptr<ASTNode> node = ParseExpr();
    Consume(TokenType::semi); // trailing ';'
    return node;
}

// Parse an if statement:
//   if-stmt : "if" "(" expr ")" stmt ( "else" stmt )?
// e.g.
//   if (a)
//     b = 3;
//   else
//     b = 4;
// A nested if checks for `else` first, so an `else` attaches to the nearest
// preceding `if`.
std::shared_ptr<ASTNode> Parser::ParseIfStmt()
{
    Consume(TokenType::kw_if);
    Consume(TokenType::l_parent);
    std::shared_ptr<ASTNode> cond = ParseExpr();
    Consume(TokenType::r_parent);

    std::shared_ptr<ASTNode> thenBranch = ParseStmt();
    std::shared_ptr<ASTNode> elseBranch;
    if (token.tokenType == TokenType::kw_else)
    {
        Consume(TokenType::kw_else);
        elseBranch = ParseStmt();
    }
    // Node construction is delegated to the semantic analyzer.
    return sema.SemaIfStmtNode(cond, thenBranch, elseBranch);
}

// Parse an expression:
//   expr        : assign-expr | add-expr
//   assign-expr : identifier "=" expr
//   add-expr    : mult-expr (("+" | "-") mult-expr)*
// One extra token of lookahead picks the alternative: an identifier
// immediately followed by '=' starts an assignment.
std::shared_ptr<ASTNode> Parser::ParseExpr()
{
    // Peek one token ahead without consuming anything: snapshot the lexer,
    // lex into a scratch copy of the current token, then roll back.
    lexer.SaveState();
    Token peek = token;
    bool startsAssignment = false;
    if (peek.tokenType == TokenType::indentifier)
    {
        lexer.NextToken(peek);
        startsAssignment = (peek.tokenType == TokenType::equal);
    }
    lexer.RestoreState();

    if (startsAssignment)
        return ParseAssignExpr();

    // add-expr: left-associative chain of + / -
    std::shared_ptr<ASTNode> lhs = ParseTerm();
    for (;;)
    {
        OpCode op;
        if (token.tokenType == TokenType::plus)
            op = OpCode::add;
        else if (token.tokenType == TokenType::minus)
            op = OpCode::sub;
        else
            break;

        Advance();
        std::shared_ptr<ASTNode> rhs = ParseTerm();
        lhs = sema.SemaBinaryExprNode(lhs, rhs, op);
    }
    return lhs;
}

// Parse an assignment:
//   assign-expr : identifier "=" expr
// The right-hand side goes through ParseExpr, so `a = b = 3` nests to the right.
std::shared_ptr<ASTNode> Parser::ParseAssignExpr()
{
    Token nameTok = token;
    Consume(TokenType::indentifier);
    std::shared_ptr<ASTNode> target = sema.SemaVariableAccessNode(nameTok);

    Token eqTok = token;
    Consume(TokenType::equal);
    std::shared_ptr<ASTNode> value = ParseExpr();
    return sema.SemaAssignExprNode(target, value, eqTok);
}

// Parse a term:
//   mult-expr : primary-expr (("*" | "/") primary-expr)*
// Operators associate to the left.
std::shared_ptr<ASTNode> Parser::ParseTerm()
{
    std::shared_ptr<ASTNode> lhs = ParseFactor();
    for (;;)
    {
        OpCode op;
        if (token.tokenType == TokenType::star)
            op = OpCode::mul;
        else if (token.tokenType == TokenType::slash)
            op = OpCode::div;
        else
            break;

        Advance();
        std::shared_ptr<ASTNode> rhs = ParseFactor();
        lhs = sema.SemaBinaryExprNode(lhs, rhs, op);
    }
    return lhs;
}

// Parse a primary expression:
//   primary-expr : identifier | number | "(" expr ")"
// A parenthesized expression returns the inner node; no node is created for
// the parentheses themselves.
std::shared_ptr<ASTNode> Parser::ParseFactor()
{
    if (token.tokenType == TokenType::l_parent)
    {
        Advance();
        auto expr = ParseExpr();
        // Previously written as assert(Expect(...)) followed by Advance(). With
        // NDEBUG that dropped the check and skipped whatever token came next.
        // Consume() checks, reports, and only advances past a real ')'.
        Consume(TokenType::r_parent);
        return expr;
    }
    else if (token.tokenType == TokenType::indentifier)
    {
        auto variableAccessExpr = sema.SemaVariableAccessNode(token);
        Advance();
        return variableAccessExpr;
    }
    else
    {
        // Anything else must be a number; Expect() reports otherwise.
        Expect(TokenType::number);
        auto factorExpr = sema.SemaNumberExprNode(token, token.type);
        Advance();
        return factorExpr;
    }
}

/// Token-consumption helpers.
// Check that the current token has kind `tokenType` without consuming it.
// On mismatch, report diag::err_expected with the expected spelling and the
// actual token text, and return false.
bool Parser::Expect(TokenType tokenType)
{
    if (token.tokenType != tokenType)
    {
        GetDiagEngine().Report(llvm::SMLoc::getFromPointer(token.ptr),
                               diag::err_expected,
                               Token::GetSpellingText(tokenType),
                               llvm::StringRef(token.ptr, token.length));
        return false;
    }
    return true;
}

// Consume the current token if it has kind `tokenType`. Otherwise leave it in
// place (Expect has already reported) and return false.
bool Parser::Consume(TokenType tokenType)
{
    if (!Expect(tokenType))
        return false;
    Advance();
    return true;
}

bool Parser::Advance()
{
    lexer.NextToken(token);
    return true;
}

printVisitor.h

#pragma once

#include "ast.h"
#include "parser.h"

// Visitor that prints the AST back as source-like text to llvm::outs().
// Printing happens in the constructor; every Visit* returns nullptr.
class PrintVisitor : public Visitor
{
public:
    PrintVisitor(std::shared_ptr<Program> program);

public:
    llvm::Value *VisitProgram(Program *p) override;
    llvm::Value *VisitDeclStmt(DeclStmt *decl) override;
    llvm::Value *VisitIfStmt(IfStmt* ifStmt) override;
    llvm::Value *VisitVariableDecl(VariableDecl *decl) override;
    llvm::Value *VisitVariableAccessExpr(VariableAccessExpr *varaccExpr) override;
    llvm::Value *VisitAssignExpr(AssignExpr *assignExpr) override;
    llvm::Value *VisitBinaryExpr(BinaryExpr *binaryExpr) override;
    llvm::Value *VisitNumberExpr(NumberExpr *factorExpr) override;
};

printVisitor.cc

#include "printVisitor.h"

// Constructing a PrintVisitor immediately walks and prints `program`.
PrintVisitor::PrintVisitor(std::shared_ptr<Program> program)
{
    Program *root = program.get();
    VisitProgram(root);
}

// Print each top-level statement, each followed by a newline.
llvm::Value *PrintVisitor::VisitProgram(Program *p)
{
    for (const std::shared_ptr<ASTNode> &stmt : p->nodeVec)
    {
        stmt->Accept(this);
        llvm::outs() << "\n";
    }
    return nullptr;
}

// Print the declaration's children (declarations and initializer
// assignments) back to back, with no separator.
llvm::Value *PrintVisitor::VisitDeclStmt(DeclStmt *decl)
{
    for (const auto &child : decl->nodeVec)
        child->Accept(this);
    return nullptr;
}

// Print `if (<cond>)`, the then-branch and, if present, `else` plus the
// else-branch, each on its own line.
llvm::Value *PrintVisitor::VisitIfStmt(IfStmt *ifStmt)
{
    llvm::outs() << "if (";
    ifStmt->condNode->Accept(this);
    llvm::outs() << ")\n";
    // thenNode is null when the branch is a null statement (`if (a) ;`), since
    // the parser produces no node for `;`. Print the `;` instead of
    // dereferencing a null pointer.
    if (ifStmt->thenNode)
        ifStmt->thenNode->Accept(this);
    else
        llvm::outs() << ";";
    llvm::outs() << "\n";
    if (ifStmt->elseNode)
    {
        llvm::outs() << "else\n";
        ifStmt->elseNode->Accept(this);
    }
    return nullptr;
}

// Print `int <name>;`. Only int declarations are printed; any other type
// produces no output.
llvm::Value *PrintVisitor::VisitVariableDecl(VariableDecl *decl)
{
    if (decl->type != CType::getIntTy())
        return nullptr;
    llvm::StringRef name(decl->token.ptr, decl->token.length);
    llvm::outs() << "int " << name << ";";
    return nullptr;
}

// A variable reference prints as its name.
llvm::Value *PrintVisitor::VisitVariableAccessExpr(VariableAccessExpr *varaccExpr)
{
    const Token &tok = varaccExpr->token;
    llvm::outs() << llvm::StringRef(tok.ptr, tok.length);
    return nullptr;
}

// Print `<target> = <value>`.
llvm::Value *PrintVisitor::VisitAssignExpr(AssignExpr *assignExpr)
{
    ASTNode *lhs = assignExpr->left.get();
    ASTNode *rhs = assignExpr->right.get();
    lhs->Accept(this);
    llvm::outs() << " = ";
    rhs->Accept(this);
    return nullptr;
}

// In-order traversal: left operand, operator, right operand.
// No parentheses are emitted (the parser keeps no node for them), so the
// printed text does not show grouping: `(a + b) * c` prints as `a + b * c`.
llvm::Value *PrintVisitor::VisitBinaryExpr(BinaryExpr *binaryExpr)
{
    llvm::StringRef opText; // stays empty for an unknown opcode
    switch (binaryExpr->op)
    {
    case OpCode::add: opText = " + "; break;
    case OpCode::sub: opText = " - "; break;
    case OpCode::mul: opText = " * "; break;
    case OpCode::div: opText = " / "; break;
    default: break;
    }

    binaryExpr->left->Accept(this);
    llvm::outs() << opText;
    binaryExpr->right->Accept(this);
    return nullptr;
}

// A number literal prints exactly as it was spelled in the source.
llvm::Value *PrintVisitor::VisitNumberExpr(NumberExpr *factorExpr)
{
    const Token &tok = factorExpr->token;
    llvm::outs() << llvm::StringRef(tok.ptr, tok.length);
    return nullptr;
}

1.4 语义分析器 (Sema)

为语义分析器新增用于创建并初始化 If 语句节点的函数(SemaIfStmtNode)。

实现代码

sema.h

#pragma once

#include "scope.h"
#include "ast.h"

// Semantic analysis: builds AST nodes on behalf of the parser, checks them
// against the symbol table (`scope`) and reports problems via `diagEngine`.
class Sema // semantic analysis
{
public:
    Sema(DiagEngine &diagEngine) : diagEngine(diagEngine) {}
    std::shared_ptr<ASTNode> SemaVariableDeclNode(Token token, CType* ty);
    std::shared_ptr<ASTNode> SemaVariableAccessNode(Token token);
    std::shared_ptr<ASTNode> SemaAssignExprNode(std::shared_ptr<ASTNode> left, std::shared_ptr<ASTNode> right, Token token);
    std::shared_ptr<ASTNode> SemaBinaryExprNode(std::shared_ptr<ASTNode> left, std::shared_ptr<ASTNode> right, OpCode op);
    std::shared_ptr<ASTNode> SemaIfStmtNode(std::shared_ptr<ASTNode> condNode, std::shared_ptr<ASTNode> thenNode, std::shared_ptr<ASTNode> elseNode);
    std::shared_ptr<ASTNode> SemaNumberExprNode(Token token, CType* ty);

    // Accessor for the diagnostics engine (name kept as-is for existing callers).
    DiagEngine &GetDiaEngine() const
    {
        return diagEngine;
    }
private:
    Scope scope;
    DiagEngine &diagEngine;
};

sema.cc

#include "sema.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/Casting.h"

// Build a VariableDecl node for `token` with type `ty` and add the name to the
// symbol table. A redefinition in the current scope is reported via
// diag::err_redefine; the symbol is still added and the node still returned.
std::shared_ptr<ASTNode> Sema::SemaVariableDeclNode(Token token, CType *ty)
{
    llvm::StringRef name(token.ptr, token.length);
    if (scope.FindVarSymbolInCurEnv(name))
    {
        GetDiaEngine().Report(llvm::SMLoc::getFromPointer(token.ptr),
                              diag::err_redefine,
                              name);
    }
    scope.AddVarSymbol(SymbolKind::LocalVariable, ty, name);

    auto decl = std::make_shared<VariableDecl>();
    decl->token = token;
    decl->type = ty;
    return decl;
}

// Build an IfStmt node. No checks are performed yet; thenNode/elseNode may be null.
std::shared_ptr<ASTNode> Sema::SemaIfStmtNode(std::shared_ptr<ASTNode> condNode, std::shared_ptr<ASTNode> thenNode, std::shared_ptr<ASTNode> elseNode)
{
    auto node = std::make_shared<IfStmt>();
    node->elseNode = elseNode;
    node->thenNode = thenNode;
    node->condNode = condNode;
    return node;
}

// Build a VariableAccessExpr for a use of the name in `token`, typed with the
// declared type of the symbol. An undefined name is reported via
// diag::err_undefine and the node's type is left null.
std::shared_ptr<ASTNode> Sema::SemaVariableAccessNode(Token token)
{
    llvm::StringRef text(token.ptr, token.length);
    // FindVarSymbol (not FindVarSymbolInCurEnv) is used on purpose; it
    // presumably also searches enclosing scopes — confirm in scope.h.
    auto symbol = scope.FindVarSymbol(text);

    auto varAcc = std::make_shared<VariableAccessExpr>();
    varAcc->token = token;

    if (symbol == nullptr)
    {
        GetDiaEngine().Report(llvm::SMLoc::getFromPointer(token.ptr),
                              diag::err_undefine,
                              llvm::StringRef(token.ptr, token.length));
        // Previously fell through to symbol->getTy() and dereferenced null.
        varAcc->type = nullptr;
        return varAcc;
    }

    varAcc->type = symbol->getTy();
    return varAcc;
}

// Build an AssignExpr `left = right`. `token` is the '=' token. Only a plain
// variable is a valid target; anything else is reported via diag::err_lvalue.
std::shared_ptr<ASTNode> Sema::SemaAssignExprNode(std::shared_ptr<ASTNode> left, std::shared_ptr<ASTNode> right, Token token)
{
    assert(left && right);

    if (!llvm::isa<VariableAccessExpr>(left.get()))
    {
        diagEngine.Report(llvm::SMLoc::getFromPointer(left->token.ptr), diag::err_lvalue);
    }

    auto assign = std::make_shared<AssignExpr>();
    assign->left = left;
    assign->right = right;
    // Previously the '=' token parameter was ignored and the node's token and
    // type were left unset. Record the operator token, like the other Sema
    // builders do with theirs, and give the assignment the type of its target.
    assign->token = token;
    assign->type = left->type;

    return assign;
}

// Build a BinaryExpr `left op right`. No type checking is done here yet.
std::shared_ptr<ASTNode> Sema::SemaBinaryExprNode(std::shared_ptr<ASTNode> left, std::shared_ptr<ASTNode> right, OpCode op)
{
    auto node = std::make_shared<BinaryExpr>();
    node->op = op;
    node->left = left;
    node->right = right;
    return node;
}

// Build a NumberExpr for the literal in `token`, with type `ty`.
std::shared_ptr<ASTNode> Sema::SemaNumberExprNode(Token token, CType *ty)
{
    auto literal = std::make_shared<NumberExpr>();
    literal->type = ty;
    literal->token = token;
    return literal;
}

1.5 代码生成 (CodeGen)

基本块划分

代码生成部分会新增几个基本块,下面是基本块的划分方式:

  • 存在 else 语句的情况

    条件块 → then块 → last块
            ↘ else块 ↗
    

    这种情况首先跳转到条件块,根据条件结果判断跳转:

    • 条件结果不等于 0:跳转到 then 块,处理完后跳转到 last 块
    • 条件结果等于 0:跳转到 else 块,处理完后跳转到 last 块
  • 不存在 else 语句的情况

    条件块 ──(真)──→ then块 ──→ last块
       └────────(假)───────────↗
    

    这种情况首先跳转到条件块,根据条件结果判断跳转:

    • 条件结果不等于 0:跳转到 then 块,处理完后跳转到 last 块
    • 条件结果等于 0:直接跳转到 last 块

实现代码

codegen.h

#pragma once

#include "ast.h"
#include "parser.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/ADT/StringMap.h"

// Generates LLVM IR by visiting the AST (visitor pattern). The constructor
// creates the module "expr" and immediately runs code generation; the module
// is printed by VisitProgram.
class CodeGen : public Visitor
{
public:
    CodeGen(std::shared_ptr<Program> program)
    {
        module = std::make_shared<llvm::Module>("expr", context);
        VisitProgram(program.get());
    }

public:
    llvm::Value *VisitProgram(Program *p) override;
    llvm::Value *VisitDeclStmt(DeclStmt *decl) override;
    llvm::Value *VisitIfStmt(IfStmt *decl) override;
    llvm::Value *VisitVariableDecl(VariableDecl *decl) override;
    llvm::Value *VisitVariableAccessExpr(VariableAccessExpr *varaccExpr) override;
    llvm::Value *VisitAssignExpr(AssignExpr *assignExpr) override;
    llvm::Value *VisitBinaryExpr(BinaryExpr *binaryExpr) override;
    llvm::Value *VisitNumberExpr(NumberExpr *factorExpr) override;

private:
    llvm::LLVMContext context;
    llvm::IRBuilder<> irBuilder{context};
    std::shared_ptr<llvm::Module> module;
    // Function currently being emitted; new basic blocks are appended to it.
    llvm::Function *curFunc{nullptr};
    // Variable name -> (address of its alloca, allocated element type).
    llvm::StringMap<std::pair<llvm::Value *, llvm::Type *>> varAddrTyMap;
};

codegen.cc

#include "codegen.h"
#include "llvm/IR/Verifier.h"

// Generate the whole module and print it to stdout. Steps: declare printf,
// create `i32 main()` with an `entry` block, emit every top-level statement
// into it, print the last statement's value (or a notice when there is no
// printable value), return 0, and verify main.
llvm::Value *CodeGen::VisitProgram(Program *p)
{
    // Declare the variadic `i32 printf(i8*, ...)`.
    auto printFunctionType = llvm::FunctionType::get(irBuilder.getInt32Ty(), {irBuilder.getInt8PtrTy()}, true);
    auto printFunction = llvm::Function::Create(printFunctionType, llvm::GlobalValue::ExternalLinkage, "printf", module.get());
    // Define `i32 main()` and start emitting into its entry block.
    auto mainFunctionType = llvm::FunctionType::get(irBuilder.getInt32Ty(), false);
    auto mainFunction = llvm::Function::Create(mainFunctionType, llvm::GlobalValue::ExternalLinkage, "main", module.get());
    llvm::BasicBlock *entryBasicBlock = llvm::BasicBlock::Create(context, "entry", mainFunction);
    irBuilder.SetInsertPoint(entryBasicBlock);
    // Statements such as `if` append new blocks to the current function.
    curFunc = mainFunction;

    llvm::Value *lastVal = nullptr;
    for (auto &node : p->nodeVec)
    {
        lastVal = node->Accept(this);
    }
    // Only an i32 can be printed with %d. A trailing `int a;` yields the
    // alloca *pointer* (VisitDeclStmt returns its last child's value), which
    // must not be passed to printf as an int.
    if (lastVal && lastVal->getType()->isIntegerTy(32))
        irBuilder.CreateCall(printFunction, {irBuilder.CreateGlobalStringPtr("expr value: %d\n"), lastVal});
    else
        irBuilder.CreateCall(printFunction, {irBuilder.CreateGlobalStringPtr("last instruction is not expr!\n")});

    llvm::Value *ret = irBuilder.CreateRet(irBuilder.getInt32(0));

    // verifyFunction() returns true when the function is malformed. Passing a
    // stream makes it describe the problems instead of failing silently.
    if (llvm::verifyFunction(*mainFunction, &llvm::errs()))
        llvm::errs() << "error: function 'main' failed IR verification\n";

    module->print(llvm::outs(), nullptr);
    return ret;
}

// Emit each child (allocas and initializer stores) in order and return the
// value produced by the last one, or nullptr if there are no children.
llvm::Value *CodeGen::VisitDeclStmt(DeclStmt *declStmt)
{
    llvm::Value *result = nullptr;
    for (const auto &child : declStmt->nodeVec)
        result = child->Accept(this);
    return result;
}

// Lower an if statement to basic blocks:
//   current -> cond; cond --(true)--> then --> last
//                    cond --(false)-> else --> last   (or directly -> last without else)
// Code emission continues in `last`. Returns nullptr (a statement has no value).
llvm::Value *CodeGen::VisitIfStmt(IfStmt *ifStmt)
{
    // Block creation order (cond, then, [else], last) fixes their layout in the function.
    llvm::BasicBlock *condBasicBlock = llvm::BasicBlock::Create(context, "cond", curFunc);
    llvm::BasicBlock *thenBasicBlock = llvm::BasicBlock::Create(context, "then", curFunc);
    llvm::BasicBlock *elseBasicBlock = nullptr;
    if (ifStmt->elseNode)
        elseBasicBlock = llvm::BasicBlock::Create(context, "else", curFunc);
    llvm::BasicBlock *lastBasicBlock = llvm::BasicBlock::Create(context, "last", curFunc);

    // LLVM has no implicit fall-through: the current block must be terminated
    // explicitly with a branch into the condition block.
    irBuilder.CreateBr(condBasicBlock);
    irBuilder.SetInsertPoint(condBasicBlock);
    llvm::Value *ret = ifStmt->condNode->Accept(this);
    // C semantics: any non-zero condition is true.
    llvm::Value *condVal = irBuilder.CreateICmpNE(ret, irBuilder.getInt32(0));
    // Without an else branch, a false condition jumps straight to `last`.
    irBuilder.CreateCondBr(condVal, thenBasicBlock,
                           elseBasicBlock ? elseBasicBlock : lastBasicBlock);

    irBuilder.SetInsertPoint(thenBasicBlock);
    // thenNode is null for a null statement (`if (a) ;`). Emit an empty
    // then-block rather than dereferencing null.
    if (ifStmt->thenNode)
        ifStmt->thenNode->Accept(this);
    irBuilder.CreateBr(lastBasicBlock);

    if (elseBasicBlock)
    {
        irBuilder.SetInsertPoint(elseBasicBlock);
        ifStmt->elseNode->Accept(this);
        irBuilder.CreateBr(lastBasicBlock);
    }

    irBuilder.SetInsertPoint(lastBasicBlock);
    return nullptr;
}

// Reserve a stack slot (alloca at the current insertion point) for the
// variable and remember its address and element type under its name.
// Returns the slot address.
llvm::Value *CodeGen::VisitVariableDecl(VariableDecl *decl)
{
    // Only `int` exists so far and maps to i32; other types leave `ty` null.
    llvm::Type *ty = (decl->type == CType::getIntTy()) ? irBuilder.getInt32Ty() : nullptr;
    llvm::StringRef name(decl->token.ptr, decl->token.length);
    llvm::Value *slot = irBuilder.CreateAlloca(ty, nullptr, name);
    varAddrTyMap.insert({name, {slot, ty}});
    return slot;
}

// Load the variable's current value (an rvalue) from its stack slot.
llvm::Value *CodeGen::VisitVariableAccessExpr(VariableAccessExpr *varaccExpr)
{
    llvm::StringRef name(varaccExpr->token.ptr, varaccExpr->token.length);
    // operator[] default-inserts {nullptr, nullptr} for an unknown name.
    auto &entry = varAddrTyMap[name];
    llvm::Type *loadTy = (varaccExpr->type == CType::getIntTy()) ? irBuilder.getInt32Ty() : entry.second;
    return irBuilder.CreateLoad(loadTy, entry.first, name);
}

// Lowers `x = expr`: evaluates the right-hand side, stores it into x's stack
// slot, then reloads x so the assignment expression itself yields an rvalue
// (e.g. `a = 3;` produces the loaded value 3).
llvm::Value *CodeGen::VisitAssignExpr(AssignExpr *assignExpr)
{
    // The parser only builds a VariableAccessExpr as the assignment target,
    // and Sema reports err_lvalue otherwise. llvm::cast asserts that invariant
    // in debug builds, whereas the previous C-style cast silently
    // reinterpreted whatever node it was given.
    auto *varAccExpr = llvm::cast<VariableAccessExpr>(assignExpr->left.get());
    llvm::StringRef text(varAccExpr->token.ptr, varAccExpr->token.length);

    // Copy the entry (not a reference): evaluating the right-hand side may
    // insert into varAddrTyMap and invalidate references into it.
    const auto [addr, ty] = varAddrTyMap[text];

    llvm::Value *rValue = assignExpr->right->Accept(this);
    irBuilder.CreateStore(rValue, addr);
    // The reloaded value is the rvalue of the whole assignment.
    return irBuilder.CreateLoad(ty, addr, text);
}

// Lowers an arithmetic binary expression to the matching LLVM instruction.
// Both operands are emitted (left first) before the operator itself.
// add/sub/mul carry the `nsw` flag: signed overflow produces poison, matching
// C's undefined signed overflow (it does not prevent overflow). Division is
// signed. Returns nullptr for an operator not handled here.
llvm::Value *CodeGen::VisitBinaryExpr(BinaryExpr *binaryExpr)
{
    llvm::Value *lhs = binaryExpr->left->Accept(this);
    llvm::Value *rhs = binaryExpr->right->Accept(this);

    llvm::Value *result = nullptr;
    switch (binaryExpr->op)
    {
    case OpCode::add:
        result = irBuilder.CreateNSWAdd(lhs, rhs, "add");
        break;
    case OpCode::sub:
        result = irBuilder.CreateNSWSub(lhs, rhs, "sub");
        break;
    case OpCode::mul:
        result = irBuilder.CreateNSWMul(lhs, rhs, "mul");
        break;
    case OpCode::div:
        result = irBuilder.CreateSDiv(lhs, rhs, "div");
        break;
    default:
        break;
    }
    return result;
}

// Materializes an integer literal as an i32 constant.
llvm::Value *CodeGen::VisitNumberExpr(NumberExpr *factorExpr)
{
    const int literal = factorExpr->token.value;
    return irBuilder.getInt32(literal);
}

1.6 测试 if 语句功能

生成 IR 文件

bin/expr test/expr_01.txt > test/expr_01.ll
bin/expr test/expr_02.txt > test/expr_02.ll
bin/expr test/expr_03.txt > test/expr_03.ll

生成的 IR 内容

expr_01.ll

; ModuleID = 'expr'
source_filename = "expr"

@0 = private unnamed_addr constant [16 x i8] c"expr value: %d\0A\00", align 1

declare i32 @printf(ptr, ...)

define i32 @main() {
entry:
  %a = alloca i32, align 4
  store i32 3, ptr %a, align 4
  %a1 = load i32, ptr %a, align 4
  %b = alloca i32, align 4
  store i32 5, ptr %b, align 4
  %b2 = load i32, ptr %b, align 4
  br label %cond

cond:                                             ; preds = %entry
  %a3 = load i32, ptr %a, align 4
  %0 = icmp ne i32 %a3, 0
  br i1 %0, label %then, label %else

then:                                             ; preds = %cond
  store i32 2, ptr %a, align 4
  %a4 = load i32, ptr %a, align 4
  br label %last

else:                                             ; preds = %cond
  store i32 4, ptr %a, align 4
  %a5 = load i32, ptr %a, align 4
  br label %last

last:                                             ; preds = %else, %then
  %a6 = load i32, ptr %a, align 4
  %b7 = load i32, ptr %b, align 4
  %mul = mul nsw i32 %a6, %b7
  %sub = sub nsw i32 %mul, 4
  %1 = call i32 (ptr, ...) @printf(ptr @0, i32 %sub)
  ret i32 0
}

expr_02.ll

; ModuleID = 'expr'
source_filename = "expr"

@0 = private unnamed_addr constant [16 x i8] c"expr value: %d\0A\00", align 1

declare i32 @printf(ptr, ...)

define i32 @main() {
entry:
  %a = alloca i32, align 4
  store i32 3, ptr %a, align 4
  %a1 = load i32, ptr %a, align 4
  %b = alloca i32, align 4
  store i32 5, ptr %b, align 4
  %b2 = load i32, ptr %b, align 4
  br label %cond

cond:                                             ; preds = %entry
  %a3 = load i32, ptr %a, align 4
  %sub = sub nsw i32 %a3, 3
  %0 = icmp ne i32 %sub, 0
  br i1 %0, label %then, label %else

then:                                             ; preds = %cond
  store i32 2, ptr %a, align 4
  %a4 = load i32, ptr %a, align 4
  br label %last

else:                                             ; preds = %cond
  store i32 4, ptr %a, align 4
  %a5 = load i32, ptr %a, align 4
  br label %last

last:                                             ; preds = %else, %then
  %a6 = load i32, ptr %a, align 4
  %b7 = load i32, ptr %b, align 4
  %mul = mul nsw i32 %a6, %b7
  %sub8 = sub nsw i32 %mul, 4
  %1 = call i32 (ptr, ...) @printf(ptr @0, i32 %sub8)
  ret i32 0
}

expr_03.ll

; ModuleID = 'expr'
source_filename = "expr"

@0 = private unnamed_addr constant [16 x i8] c"expr value: %d\0A\00", align 1

declare i32 @printf(ptr, ...)

define i32 @main() {
entry:
  %a = alloca i32, align 4
  store i32 3, ptr %a, align 4
  %a1 = load i32, ptr %a, align 4
  %b = alloca i32, align 4
  store i32 5, ptr %b, align 4
  %b2 = load i32, ptr %b, align 4
  br label %cond

cond:                                             ; preds = %entry
  %a3 = load i32, ptr %a, align 4
  %0 = icmp ne i32 %a3, 0
  br i1 %0, label %then, label %last

then:                                             ; preds = %cond
  store i32 2, ptr %a, align 4
  %a4 = load i32, ptr %a, align 4
  br label %last

last:                                             ; preds = %then, %cond
  %a5 = load i32, ptr %a, align 4
  %b6 = load i32, ptr %b, align 4
  %mul = mul nsw i32 %a5, %b6
  %sub = sub nsw i32 %mul, 4
  %1 = call i32 (ptr, ...) @printf(ptr @0, i32 %sub)
  ret i32 0
}

运行 IR

lli test/expr_01.ll
lli test/expr_02.ll
lli test/expr_03.ll

运行结果

expr_01.ll

expr value: 6

expr_02.ll

expr value: 16

expr_03.ll

expr value: 6

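结果均符合预期:expr_01 中 a 为 3(非零),走 then 分支,a = 2,2 * 5 - 4 = 6;expr_02 中 a - 3 为 0,走 else 分支,a = 4,4 * 5 - 4 = 16;expr_03 没有 else 分支,条件为真时 a = 2,结果同样为 6。这验证了当前基础 if 语句实现的正确性。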

2. 实现块语句的功能

2.1 测试文件

expr.txt

int a = 3;
int b = 5;
if (a)
{
  a = 2;
  b = 3;
}
else 
{  
  a = 4;
  b = 2;
}
a * b - 4;

2.2 词法分析器 (Lexer)

文法定义

ebnf.txt

prog : stmt*
stmt : decl-stmt | expr-stmt | null-stmt | if-stmt | block-stmt
null-stmt : ";"
decl-stmt : "int" identifier ("=" expr)? ("," identifier ("=" expr)?)* ";"
if-stmt : "if" "(" expr ")" stmt ( "else" stmt )?
block-stmt : "{" stmt* "}"
expr-stmt : expr ";"
expr : assign-expr | add-expr
assign-expr: identifier "=" expr
add-expr : mult-expr (("+" | "-") mult-expr)* 
mult-expr : primary-expr (("*" | "/") primary-expr)* 
primary-expr : identifier | number | "(" expr ")" 
number: ([0-9])+ 
identifier : [a-zA-Z_][a-zA-Z0-9_]*

实现代码

词法分析器新增左右大括号 `{`、`}` 两种 token 类型(l_brace、r_brace),并在 NextToken 中识别它们。

lexer.h

#pragma once

#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"
#include "type.h"
#include "diag_engine.h"

// char stream -> token

// Kinds of tokens produced by Lexer::NextToken.
enum class TokenType
{
    number,      // [0-9]+
    indentifier, // identifier (variable name); enumerator spelling kept as-is
    kw_int,      // keyword: int
    kw_if,       // keyword: if
    kw_else,     // keyword: else
    plus,        // +
    minus,       // -
    star,        // *
    slash,       // /
    equal,       // =
    l_parent,    // (
    r_parent,    // )
    l_brace,     // {
    r_brace,     // }
    semi,        // ;
    comma,       // ,
    eof          // end of file
};

// A single lexical token: its kind, 1-based source position, the text span it
// covers in the source buffer, and (for number literals) its value and type.
// Every member has a default initializer. Tokens are not always filled by
// Lexer::NextToken before being read: Parser::token before the first
// Advance(), and the `token` of AST nodes whose Sema builder never sets it
// (e.g. BinaryExpr). Such tokens now hold defined values instead of
// indeterminate ones.
class Token
{
public:
    TokenType tokenType{TokenType::eof}; // kind of the token
    int row{0}, col{0};                  // 1-based line and column of the first character

    int value{0}; // numeric value; only meaningful for TokenType::number

    const char *ptr{nullptr}; // start of the token text in the source buffer
    int length{0};            // number of characters of token text

    CType *type{nullptr}; // built-in type; set for number literals

public:
    // Prints "{text, row = R, col = C}" to llvm::outs().
    void Dump();
    // Fixed spelling of a token kind, used in "expected ..." diagnostics.
    static llvm::StringRef GetSpellingText(TokenType tokenType);
};

// Turns the main file of a SourceMgr into a stream of Tokens (see NextToken).
// SaveState/RestoreState provide a single-slot checkpoint used for lookahead.
class Lexer
{
private:
    DiagEngine &diagEngine;
    llvm::SourceMgr &mgr;

public:
    Lexer(DiagEngine &diagEngine, llvm::SourceMgr &mgr) : diagEngine(diagEngine), mgr(mgr)
    {
        // Start scanning at the first character of the main file, on line 1.
        const llvm::StringRef source = mgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer();
        BufPtr = source.begin();
        LineHeadPtr = BufPtr;
        BufEnd = source.end();
        row = 1;
    }

    // Scans the next token into `token`.
    void NextToken(Token &token);

    // Checkpoint and rewind of the scanning position (one slot only).
    void SaveState();
    void RestoreState();

    DiagEngine &GetDiagEngine() const { return diagEngine; }

private:
    // Snapshot of the scanning position taken by SaveState().
    struct State
    {
        const char *BufPtr;
        const char *LineHeadPtr;
        const char *BufEnd;
        int row;
    };

private:
    const char *BufPtr;      // next character to scan
    const char *LineHeadPtr; // first character of the current line (for columns)
    const char *BufEnd;      // one past the last character of the buffer
    int row;                 // current 1-based line number

    State state;
};

lexer.cc

#include "lexer.h"

void Token::Dump()
{
    llvm::StringRef text(ptr, length);
    llvm::outs() << "{" << text << ", row = " << row << ", col = " << col << "}\n";
}

// Spelling of a token kind, used in diagnostics such as "expected ';'".
// Kinds without a fixed spelling (number, identifier, end of file) return a
// descriptive name instead. Previously `eof` had no case and fell into the
// unreachable default, and the identifier name was misspelled "indentifier".
llvm::StringRef Token::GetSpellingText(TokenType tokenType)
{
    // No `default:` on purpose: -Wswitch then flags any TokenType that is
    // added to the enum but not spelled here.
    switch (tokenType)
    {
    case TokenType::number:
        return "number";
    case TokenType::indentifier:
        return "identifier";
    case TokenType::kw_int:
        return "int";
    case TokenType::kw_if:
        return "if";
    case TokenType::kw_else:
        return "else";
    case TokenType::plus:
        return "+";
    case TokenType::minus:
        return "-";
    case TokenType::star:
        return "*";
    case TokenType::slash:
        return "/";
    case TokenType::equal:
        return "=";
    case TokenType::l_parent:
        return "(";
    case TokenType::r_parent:
        return ")";
    case TokenType::l_brace:
        return "{";
    case TokenType::r_brace:
        return "}";
    case TokenType::semi:
        return ";";
    case TokenType::comma:
        return ",";
    case TokenType::eof:
        return "end of file";
    }
    llvm_unreachable("invalid TokenType");
}

// True for the blank characters the lexer skips between tokens:
// space, horizontal tab, carriage return and newline.
bool IsWhiteSpace(char ch)
{
    switch (ch)
    {
    case ' ':
    case '\t':
    case '\r':
    case '\n':
        return true;
    default:
        return false;
    }
}

// True for the ASCII decimal digits '0'..'9'.
bool IsDigit(char ch)
{
    return '0' <= ch && ch <= '9';
}

// True for characters that may start an identifier: ASCII letters and '_'.
bool IsLetter(char ch)
{
    if (ch == '_')
        return true;
    const bool lower = 'a' <= ch && ch <= 'z';
    const bool upper = 'A' <= ch && ch <= 'Z';
    return lower || upper;
}

// Scans the next token starting at BufPtr and stores it in `token`.
// Whitespace is skipped first; a newline advances `row` and resets the line
// start used to compute the 1-based `col`. At end of input the kind is `eof`.
// An unknown character is reported through diagEngine, and in that case the
// token kind is left unset.
void Lexer::NextToken(Token &token)
{
    // Skip whitespace, tracking line starts for row/column information.
    // Bounded by BufEnd rather than relying on the buffer's null terminator.
    while (BufPtr < BufEnd && IsWhiteSpace(*BufPtr))
    {
        if (*BufPtr == '\n')
        {
            row += 1;
            LineHeadPtr = BufPtr + 1;
        }
        BufPtr++;
    }

    token.row = row;
    token.col = BufPtr - LineHeadPtr + 1;
    // Set the text span before the eof check. Diagnostics that point at an
    // eof token (e.g. a missing ';') then get a valid location instead of a
    // stale or uninitialized pointer.
    token.ptr = BufPtr;
    token.length = 0;

    if (BufPtr >= BufEnd)
    {
        token.tokenType = TokenType::eof;
        return;
    }

    if (IsDigit(*BufPtr)) // number literal
    {
        // NOTE(review): very long literals overflow `int`; there is no range check yet.
        int val = 0;
        while (BufPtr < BufEnd && IsDigit(*BufPtr))
        {
            val = val * 10 + (*BufPtr++ - '0');
            token.length++;
        }
        token.value = val;
        token.tokenType = TokenType::number;
        token.type = CType::getIntTy();
    }
    else if (IsLetter(*BufPtr)) // identifier or keyword
    {
        while (BufPtr < BufEnd && (IsLetter(*BufPtr) || IsDigit(*BufPtr)))
        {
            BufPtr++;
        }
        token.length = BufPtr - token.ptr;
        llvm::StringRef text(token.ptr, token.length);
        if (text == "int")
            token.tokenType = TokenType::kw_int;
        else if (text == "if")
            token.tokenType = TokenType::kw_if;
        else if (text == "else")
            token.tokenType = TokenType::kw_else;
        else
            token.tokenType = TokenType::indentifier;
    }
    else // single-character punctuation
    {
        TokenType kind;
        switch (*BufPtr)
        {
        case '+': kind = TokenType::plus; break;
        case '-': kind = TokenType::minus; break;
        case '*': kind = TokenType::star; break;
        case '/': kind = TokenType::slash; break;
        case '=': kind = TokenType::equal; break;
        case '(': kind = TokenType::l_parent; break;
        case ')': kind = TokenType::r_parent; break;
        case ';': kind = TokenType::semi; break;
        case ',': kind = TokenType::comma; break;
        case '{': kind = TokenType::l_brace; break;
        case '}': kind = TokenType::r_brace; break;
        default:
            diagEngine.Report(llvm::SMLoc::getFromPointer(BufPtr), diag::err_unknown_char, *BufPtr);
            return;
        }
        token.tokenType = kind;
        token.length = 1;
        BufPtr++;
    }
}

void Lexer::SaveState()
{
    state.LineHeadPtr = LineHeadPtr;
    state.BufPtr = BufPtr;
    state.BufEnd = BufEnd;
    state.row = row;
}

// Rewinds the scanning position to the snapshot taken by the last SaveState().
void Lexer::RestoreState()
{
    BufPtr = state.BufPtr;
    LineHeadPtr = state.LineHeadPtr;
    BufEnd = state.BufEnd;
    row = state.row;
}

2.3 语法分析器 (Parser)

实现代码

在抽象语法树 (AST) 中新增块语句节点 BlockStmt,并在语法分析器中新增块语句的解析函数 ParseBlockStmt。

ast.h

#pragma once

#include <vector>
#include <memory>
#include "llvm/IR/Value.h"
#include "type.h"
#include "lexer.h"

// Binary arithmetic operators carried by BinaryExpr::op.
enum class OpCode
{
    add, // +
    sub, // -
    mul, // *
    div  // /
};

// 进行声明
class Program;
class Expr;
class DeclStmt;
class BlockStmt;
class VariableDecl;
class IfStmt;
class VariableAccessExpr;
class BinaryExpr;
class AssignExpr;
class NumberExpr;

// 访问者模式
// Visitor interface over the AST (visitor pattern). Each AST node class's
// Accept() calls the matching Visit* method; CodeGen returns generated
// llvm::Value*s, PrintVisitor returns nullptr.
class Visitor
{
public:
    virtual ~Visitor() = default;

    virtual llvm::Value *VisitProgram(Program *p) = 0;
    // statements
    virtual llvm::Value *VisitDeclStmt(DeclStmt *decl) = 0;
    virtual llvm::Value *VisitBlockStmt(BlockStmt *block) = 0;
    virtual llvm::Value *VisitIfStmt(IfStmt *ifStmt) = 0;
    virtual llvm::Value *VisitVariableDecl(VariableDecl *decl) = 0;
    // expressions
    virtual llvm::Value *VisitVariableAccessExpr(VariableAccessExpr *varaccExpr) = 0;
    virtual llvm::Value *VisitAssignExpr(AssignExpr *assignExpr) = 0;
    virtual llvm::Value *VisitBinaryExpr(BinaryExpr *binaryExpr) = 0;
    virtual llvm::Value *VisitNumberExpr(NumberExpr *factorExpr) = 0;
};

// 语法树的公共节点
// Common base class of all AST nodes.
// `kind` identifies the concrete node class for LLVM-style RTTI (the
// subclasses' classof(), used by llvm::isa/cast/dyn_cast). Accept()
// dispatches to the matching Visitor method and is overridden by each
// subclass.
class ASTNode
{
public:
    enum Kind
    {
        ND_BlockStmt,
        ND_IfStmt,
        ND_DeclStmt,
        ND_VariableDecl,
        ND_BinaryExpr,
        ND_NumberExpr,
        ND_VariableAccessExpr,
        ND_AssignExpr
    };

private:
    const Kind kind;

public:
    ASTNode(Kind kind) : kind(kind) {}
    virtual ~ASTNode() {}
    // Base implementation visits nothing; concrete nodes override it.
    virtual llvm::Value *Accept(Visitor *v) { return nullptr; }
    // Returned by plain value: the former `const Kind` return type had no
    // effect and triggers -Wignored-qualifiers.
    Kind getKind() const { return kind; }

public:
    // Semantic type. Sema sets it for variable declarations, variable accesses
    // and number literals. For other nodes it now stays nullptr instead of
    // holding an indeterminate pointer.
    CType *type{nullptr};
    Token token;
};

// Declaration statement node. `int a, b = 3;` is stored as a flat list of
// VariableDecl nodes interleaved with AssignExpr nodes for the initializers.
class DeclStmt : public ASTNode
{
public:
    std::vector<std::shared_ptr<ASTNode>> nodeVec; // children in source order

    DeclStmt() : ASTNode(ND_DeclStmt) {}

    // LLVM-style RTTI hook for llvm::isa / cast / dyn_cast.
    static bool classof(const ASTNode *node) { return node->getKind() == ND_DeclStmt; }

    llvm::Value *Accept(Visitor *v) override { return v->VisitDeclStmt(this); }
};

// Block statement node `{ stmt* }`; children in source order (the parser
// drops null statements).
class BlockStmt : public ASTNode
{
public:
    std::vector<std::shared_ptr<ASTNode>> nodeVec;

    BlockStmt() : ASTNode(ND_BlockStmt) {}

    // LLVM-style RTTI hook for llvm::isa / cast / dyn_cast.
    static bool classof(const ASTNode *node) { return node->getKind() == ND_BlockStmt; }

    llvm::Value *Accept(Visitor *v) override { return v->VisitBlockStmt(this); }
};

// Conditional statement node `if (cond) then [else otherwise]`.
// elseNode is null when there is no else branch.
class IfStmt : public ASTNode
{
public:
    std::shared_ptr<ASTNode> condNode;
    std::shared_ptr<ASTNode> thenNode;
    std::shared_ptr<ASTNode> elseNode;

    IfStmt() : ASTNode(ND_IfStmt) {}

    // LLVM-style RTTI hook for llvm::isa / cast / dyn_cast.
    static bool classof(const ASTNode *node) { return node->getKind() == ND_IfStmt; }

    llvm::Value *Accept(Visitor *v) override { return v->VisitIfStmt(this); }
};

// Variable declaration node: the name lives in `token`, the declared type in `type`.
class VariableDecl : public ASTNode
{
public:
    VariableDecl() : ASTNode(ND_VariableDecl) {}

    // LLVM-style RTTI hook for llvm::isa / cast / dyn_cast.
    static bool classof(const ASTNode *node) { return node->getKind() == ND_VariableDecl; }

    llvm::Value *Accept(Visitor *v) override { return v->VisitVariableDecl(this); }
};

// Binary arithmetic expression node `left op right`.
class BinaryExpr : public ASTNode
{
public:
    OpCode op;
    std::shared_ptr<ASTNode> left;
    std::shared_ptr<ASTNode> right;

    BinaryExpr() : ASTNode(ND_BinaryExpr) {}

    // LLVM-style RTTI hook for llvm::isa / cast / dyn_cast.
    static bool classof(const ASTNode *node) { return node->getKind() == ND_BinaryExpr; }

    llvm::Value *Accept(Visitor *v) override { return v->VisitBinaryExpr(this); }
};

// Assignment expression node `left = right`; `left` is expected to be a
// VariableAccessExpr (Sema reports err_lvalue otherwise).
class AssignExpr : public ASTNode
{
public:
    std::shared_ptr<ASTNode> left;
    std::shared_ptr<ASTNode> right;

    AssignExpr() : ASTNode(ND_AssignExpr) {}

    // LLVM-style RTTI hook for llvm::isa / cast / dyn_cast.
    static bool classof(const ASTNode *node) { return node->getKind() == ND_AssignExpr; }

    llvm::Value *Accept(Visitor *v) override { return v->VisitAssignExpr(this); }
};

// Integer literal node; the value is stored in token.value.
class NumberExpr : public ASTNode
{
public:
    NumberExpr() : ASTNode(ND_NumberExpr) {}

    // LLVM-style RTTI hook for llvm::isa / cast / dyn_cast.
    static bool classof(const ASTNode *node) { return node->getKind() == ND_NumberExpr; }

    llvm::Value *Accept(Visitor *v) override { return v->VisitNumberExpr(this); }
};

// Variable reference expression node; the name lives in `token`.
class VariableAccessExpr : public ASTNode
{
public:
    VariableAccessExpr() : ASTNode(ND_VariableAccessExpr) {}

    // LLVM-style RTTI hook for llvm::isa / cast / dyn_cast.
    static bool classof(const ASTNode *node) { return node->getKind() == ND_VariableAccessExpr; }

    llvm::Value *Accept(Visitor *v) override { return v->VisitVariableAccessExpr(this); }
};

// Root of the AST: the top-level statements in source order (ParseProgram
// drops null statements).
class Program
{
public:
    std::vector<std::shared_ptr<ASTNode>> nodeVec;
};

parser.h

#pragma once

#include "ast.h"
#include "lexer.h"
#include "sema.h"

// Recursive-descent parser: pulls tokens from a Lexer and builds the AST,
// delegating node construction and semantic checks to Sema.
class Parser
{
public:
    // Primes `token` with the first token of the input.
    Parser(Lexer &lexer, Sema &sema) : lexer(lexer), sema(sema)
    {
        Advance();
    }
    std::shared_ptr<Program> ParseProgram();

private:
    // One method per grammar rule.
    std::shared_ptr<ASTNode> ParseStmt();
    std::shared_ptr<ASTNode> ParseDeclStmt();
    std::shared_ptr<ASTNode> ParseBlockStmt();
    std::shared_ptr<ASTNode> ParseExprStmt();
    std::shared_ptr<ASTNode> ParseIfStmt();
    std::shared_ptr<ASTNode> ParseAssignExpr();
    std::shared_ptr<ASTNode> ParseExpr();   // expr / add-expr
    std::shared_ptr<ASTNode> ParseTerm();   // mult-expr
    std::shared_ptr<ASTNode> ParseFactor(); // primary-expr

    // Token helpers.
    bool Expect(TokenType tokenType);  // check the kind and report a mismatch; never consumes
    bool Consume(TokenType tokenType); // Expect(), then advance on success
    bool Advance();                    // move to the next token unconditionally

    DiagEngine &GetDiagEngine() const { return lexer.GetDiagEngine(); }

private:
    Lexer &lexer;
    Sema &sema;
    Token token; // current (lookahead) token
};

parser.cc

#include "parser.h"
#include <cassert>

/*
prog : stmt*
stmt : decl-stmt | expr-stmt | null-stmt | if-stmt | block-stmt
null-stmt : ";"
decl-stmt : "int" identifier ("=" expr)? ("," identifier ("=" expr)?)* ";"
if-stmt : "if" "(" expr ")" stmt ( "else" stmt )?
block-stmt : "{" stmt* "}"
expr-stmt : expr ";"
expr : assign-expr | add-expr
assign-expr: identifier "=" expr
add-expr : mult-expr (("+" | "-") mult-expr)*
mult-expr : primary-expr (("*" | "/") primary-expr)*
primary-expr : identifier | number | "(" expr ")"
number: ([0-9])+
identifier : [a-zA-Z_][a-zA-Z0-9_]*
*/

// Parses the whole input: prog : stmt*
// Statements are parsed until eof. Null statements (ParseStmt returns
// nullptr for a bare ';') are not stored.
std::shared_ptr<Program> Parser::ParseProgram()
{
    auto program = std::make_shared<Program>();
    while (token.tokenType != TokenType::eof)
    {
        if (std::shared_ptr<ASTNode> stmt = ParseStmt())
            program->nodeVec.push_back(std::move(stmt));
    }
    return program;
}

// Parses one statement, dispatching on the current token:
//   ";"   -> null-stmt (consumed; returns nullptr)
//   "int" -> decl-stmt
//   "{"   -> block-stmt
//   "if"  -> if-stmt
//   other -> expr-stmt
std::shared_ptr<ASTNode> Parser::ParseStmt()
{
    switch (token.tokenType)
    {
    case TokenType::semi:
        Consume(TokenType::semi);
        return nullptr;
    case TokenType::kw_int:
        return ParseDeclStmt();
    case TokenType::l_brace:
        return ParseBlockStmt();
    case TokenType::kw_if:
        return ParseIfStmt();
    default:
        return ParseExprStmt();
    }
}

// 解析声明语句
std::shared_ptr<ASTNode> Parser::ParseDeclStmt()
{
    /// int a, b = 3;
    /// int a = 3;
    Consume(TokenType::kw_int);
    CType *baseTy = CType::getIntTy();
    /// a , b = 3;
    /// a = 3;

    auto declStmt = std::make_shared<DeclStmt>();

    /// a, b = 3;
    /// a = 3;
    int i = 0;
    while (token.tokenType != TokenType::semi)
    {
        if (i++ > 0) // if (i++)
        {
            assert(Consume(TokenType::comma));
        }

        /// 变量声明的节点: int a = 3; -> int a; a = 3;
        // a = 3;
        auto variableDecl = sema.SemaVariableDeclNode(token, baseTy); // get a type
        declStmt->nodeVec.push_back(variableDecl);

        Token tmp = token;
        Consume(TokenType::indentifier);

        // = 3;
        if (token.tokenType == TokenType::equal)
        {
            Token opToken = token;
            Advance();

            // 3;
            auto right = ParseExpr();
            auto left = sema.SemaVariableAccessNode(tmp);
            auto assign = sema.SemaAssignExprNode(left, right, opToken);

            declStmt->nodeVec.push_back(assign);
        }
    }

    Advance();

    return declStmt;
}

// Parses block-stmt : "{" stmt* "}"
// Sema enters a new scope for the duration of the block and leaves it after
// the closing brace. The exact lookup rules live in Scope.
std::shared_ptr<ASTNode> Parser::ParseBlockStmt()
{
    sema.EnterScope();

    auto blockStmt = std::make_shared<BlockStmt>();
    Consume(TokenType::l_brace);
    // Also stop at eof: an unterminated block previously re-parsed the eof
    // token forever. Now Consume() below reports the missing '}'.
    while (token.tokenType != TokenType::r_brace && token.tokenType != TokenType::eof)
    {
        if (auto stmt = ParseStmt())
            blockStmt->nodeVec.push_back(stmt);
    }
    Consume(TokenType::r_brace);

    sema.ExitScope();
    return blockStmt;
}

// Parses expr-stmt : expr ";" and returns the expression node itself.
std::shared_ptr<ASTNode> Parser::ParseExprStmt()
{
    std::shared_ptr<ASTNode> node = ParseExpr();
    Consume(TokenType::semi);
    return node;
}

// if-stmt : "if" "(" expr ")" stmt ( "else" stmt )?
/*
if (a)
  b = 3;
else
    b = 4;
*/
std::shared_ptr<ASTNode> Parser::ParseIfStmt()
{
    Consume(TokenType::kw_if);
    Consume(TokenType::l_parent);
    auto condExpr = ParseExpr();
    Consume(TokenType::r_parent);
    auto thenStmt = ParseStmt();
    std::shared_ptr<ASTNode> elseStmt = nullptr;
    if (token.tokenType == TokenType::kw_else)
    {
        Consume(TokenType::kw_else);
        elseStmt = ParseStmt();
    }
    return sema.SemaIfStmtNode(condExpr, thenStmt, elseStmt);
}

// 解析表达式
// expr : assign-expr | add-expr
// assign-expr: identifier "=" expr
// add-expr : mult-expr (("+" | "-") mult-expr)*
std::shared_ptr<ASTNode> Parser::ParseExpr()
{
    lexer.SaveState();
    bool isAssign = false;
    Token tmp = token;
    // a = b;
    if (tmp.tokenType == TokenType::indentifier)
    {
        lexer.NextToken(tmp);
        if (tmp.tokenType == TokenType::equal)
        {
            isAssign = true;
        }
    }
    lexer.RestoreState();
    if (isAssign)
    {
        return ParseAssignExpr();
    }
    // add-expr
    std::shared_ptr<ASTNode> left = ParseTerm();
    while (token.tokenType == TokenType::plus || token.tokenType == TokenType::minus)
    {
        OpCode op;
        if (token.tokenType == TokenType::plus)
        {
            op = OpCode::add;
        }
        else
        {
            op = OpCode::sub;
        }
        Advance();
        auto right = ParseTerm();
        auto binaryExpr = sema.SemaBinaryExprNode(left, right, op);

        left = binaryExpr;
    }
    return left;
}

// 解析赋值表达式
std::shared_ptr<ASTNode> Parser::ParseAssignExpr()
{
    // a = b;
    Token tmp = token;
    Consume(TokenType::indentifier);
    auto expr = sema.SemaVariableAccessNode(tmp);
    Token opToken = token;
    Consume(TokenType::equal);
    return sema.SemaAssignExprNode(expr, ParseExpr(), opToken);
}

// Parses mult-expr : primary-expr (("*" | "/") primary-expr)*
// folding `*`/`/` chains left-associatively (a / b * c == (a / b) * c).
std::shared_ptr<ASTNode> Parser::ParseTerm()
{
    std::shared_ptr<ASTNode> lhs = ParseFactor();
    for (;;)
    {
        OpCode op;
        if (token.tokenType == TokenType::star)
            op = OpCode::mul;
        else if (token.tokenType == TokenType::slash)
            op = OpCode::div;
        else
            break;
        Advance();
        std::shared_ptr<ASTNode> rhs = ParseFactor();
        lhs = sema.SemaBinaryExprNode(lhs, rhs, op);
    }
    return lhs;
}

// 解析因子
std::shared_ptr<ASTNode> Parser::ParseFactor()
{
    if (token.tokenType == TokenType::l_parent)
    {
        Advance();
        auto expr = ParseExpr();
        assert(Expect(TokenType::r_parent));
        Advance();
        return expr;
    }
    else if (token.tokenType == TokenType::indentifier)
    {
        auto variableAccessExpr = sema.SemaVariableAccessNode(token);
        Advance();
        return variableAccessExpr;
    }
    else
    {
        Expect(TokenType::number);
        auto factorExpr = sema.SemaNumberExprNode(token, token.type);
        Advance();
        return factorExpr;
    }
}

// Returns true when the current token has kind `tokenType`. Otherwise it
// reports err_expected at the current token (expected spelling vs. actual
// text) and returns false. Never consumes the token.
bool Parser::Expect(TokenType tokenType)
{
    if (token.tokenType != tokenType)
    {
        GetDiagEngine().Report(llvm::SMLoc::getFromPointer(token.ptr),
                               diag::err_expected,
                               Token::GetSpellingText(tokenType),
                               llvm::StringRef(token.ptr, token.length));
        return false;
    }
    return true;
}

// Advances past the current token only if it has kind `tokenType`; a failed
// check is reported by Expect. Returns whether the token matched.
bool Parser::Consume(TokenType tokenType)
{
    const bool matched = Expect(tokenType);
    if (matched)
        Advance();
    return matched;
}

// Unconditionally replaces the current token with the next one from the
// lexer. Always returns true.
bool Parser::Advance()
{
    lexer.NextToken(token);
    return true;
}

printVisitor.h

#pragma once

#include "ast.h"
#include "parser.h"

// Debug visitor that prints the AST back as source-like text to llvm::outs().
// Printing happens in the constructor.
class PrintVisitor : public Visitor
{
public:
    PrintVisitor(std::shared_ptr<Program> program);

public:
    llvm::Value *VisitProgram(Program *p) override;
    // statements
    llvm::Value *VisitDeclStmt(DeclStmt *decl) override;
    llvm::Value *VisitBlockStmt(BlockStmt *block) override;
    llvm::Value *VisitIfStmt(IfStmt *ifStmt) override;
    llvm::Value *VisitVariableDecl(VariableDecl *decl) override;
    // expressions
    llvm::Value *VisitVariableAccessExpr(VariableAccessExpr *varaccExpr) override;
    llvm::Value *VisitAssignExpr(AssignExpr *assignExpr) override;
    llvm::Value *VisitBinaryExpr(BinaryExpr *binaryExpr) override;
    llvm::Value *VisitNumberExpr(NumberExpr *factorExpr) override;
};

printVisitor.cc

#include "printVisitor.h"

// Printing is eager: constructing a PrintVisitor walks `program` and writes
// it to llvm::outs().
PrintVisitor::PrintVisitor(std::shared_ptr<Program> program)
{
    Program *root = program.get();
    VisitProgram(root);
}

// Prints every top-level statement, each followed by a newline.
llvm::Value *PrintVisitor::VisitProgram(Program *p)
{
    for (const std::shared_ptr<ASTNode> &stmt : p->nodeVec)
    {
        stmt->Accept(this);
        llvm::outs() << "\n";
    }
    return nullptr;
}

// Prints the children of a declaration statement back to back
// (VariableDecl and AssignExpr nodes, in source order).
llvm::Value *PrintVisitor::VisitDeclStmt(DeclStmt *decl)
{
    // Bind by const reference: iterating by value copied each shared_ptr,
    // costing an atomic refcount increment and decrement per child.
    for (const auto &node : decl->nodeVec)
    {
        node->Accept(this);
    }
    return nullptr;
}

// Prints `{`, each child statement on its own line, then `}`.
llvm::Value *PrintVisitor::VisitBlockStmt(BlockStmt *block)
{
    auto &out = llvm::outs();
    out << "{\n";
    for (const auto &stmt : block->nodeVec)
    {
        stmt->Accept(this);
        out << "\n";
    }
    out << "}\n";
    return nullptr;
}

// Prints `if (cond)`, the then-branch, and the else-branch when present.
llvm::Value *PrintVisitor::VisitIfStmt(IfStmt *ifStmt)
{
    llvm::outs() << "if (";
    ifStmt->condNode->Accept(this);
    llvm::outs() << ")\n";
    // `if (c) ;` is parsed with a null then-branch (ParseStmt returns nullptr
    // for a null statement). Print it as ";" instead of dereferencing nullptr.
    if (ifStmt->thenNode)
        ifStmt->thenNode->Accept(this);
    else
        llvm::outs() << ";";
    llvm::outs() << "\n";
    if (ifStmt->elseNode)
    {
        llvm::outs() << "else\n";
        ifStmt->elseNode->Accept(this);
    }
    return nullptr;
}

// Prints `int name;`. Declarations of any other type print nothing.
llvm::Value *PrintVisitor::VisitVariableDecl(VariableDecl *decl)
{
    if (decl->type != CType::getIntTy())
        return nullptr;
    llvm::outs() << "int " << llvm::StringRef(decl->token.ptr, decl->token.length) << ";";
    return nullptr;
}

// Prints the variable name as spelled in the source.
llvm::Value *PrintVisitor::VisitVariableAccessExpr(VariableAccessExpr *varaccExpr)
{
    const Token &name = varaccExpr->token;
    llvm::outs() << llvm::StringRef(name.ptr, name.length);
    return nullptr;
}

// Prints `target = value`.
llvm::Value *PrintVisitor::VisitAssignExpr(AssignExpr *assignExpr)
{
    ASTNode *target = assignExpr->left.get();
    ASTNode *value = assignExpr->right.get();
    target->Accept(this);
    llvm::outs() << " = ";
    value->Accept(this);
    return nullptr;
}

llvm::Value *PrintVisitor::VisitBinaryExpr(BinaryExpr *binaryExpr)
{
    // 后序遍历
    binaryExpr->left->Accept(this);

    switch (binaryExpr->op)
    {
    case OpCode::add:
    {
        llvm::outs() << " + ";
        break;
    }
    case OpCode::sub:
    {
        llvm::outs() << " - ";
        break;
    }
    case OpCode::mul:
    {
        llvm::outs() << " * ";
        break;
    }
    case OpCode::div:
    {
        llvm::outs() << " / ";
        break;
    }
    default:
    {
        break;
    }
    }

    binaryExpr->right->Accept(this);

    return nullptr;
}

// Prints a number literal exactly as it was spelled in the source.
llvm::Value *PrintVisitor::VisitNumberExpr(NumberExpr *factorExpr)
{
    const Token &literal = factorExpr->token;
    llvm::outs() << llvm::StringRef(literal.ptr, literal.length);
    return nullptr;
}

2.4 语义分析器 (Sema)

语义分析器新增作用域管理接口 EnterScope / ExitScope:语法分析器在解析块语句时进入新的作用域,解析完右大括号后退出。

实现代码

sema.h

#pragma once

#include "scope.h"
#include "ast.h"

// Semantic analysis. Builds AST nodes on behalf of the parser and checks them
// against the symbol table: redefinition, undefined names and assignability.
class Sema
{
public:
    Sema(DiagEngine &diagEngine) : diagEngine(diagEngine) {}

    // Node builders.
    std::shared_ptr<ASTNode> SemaVariableDeclNode(Token token, CType *ty);
    std::shared_ptr<ASTNode> SemaVariableAccessNode(Token token);
    std::shared_ptr<ASTNode> SemaAssignExprNode(std::shared_ptr<ASTNode> left, std::shared_ptr<ASTNode> right, Token token);
    std::shared_ptr<ASTNode> SemaBinaryExprNode(std::shared_ptr<ASTNode> left, std::shared_ptr<ASTNode> right, OpCode op);
    std::shared_ptr<ASTNode> SemaIfStmtNode(std::shared_ptr<ASTNode> condNode, std::shared_ptr<ASTNode> thenNode, std::shared_ptr<ASTNode> elseNode);
    std::shared_ptr<ASTNode> SemaNumberExprNode(Token token, CType *ty);

    // Scope management (the parser uses these around block statements).
    void EnterScope();
    void ExitScope();

    DiagEngine &GetDiaEngine() const { return diagEngine; }

private:
    Scope scope; // symbol table
    DiagEngine &diagEngine;
};

sema.cc

#include "sema.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/Casting.h"

// Declares `token` as a local variable of type `ty` in the current scope and
// returns its VariableDecl node. A name already present in the current scope
// is reported as err_redefine; the symbol is still added afterwards.
std::shared_ptr<ASTNode> Sema::SemaVariableDeclNode(Token token, CType *ty)
{
    llvm::StringRef name(token.ptr, token.length);
    if (scope.FindVarSymbolInCurEnv(name))
    {
        GetDiaEngine().Report(llvm::SMLoc::getFromPointer(token.ptr),
                              diag::err_redefine,
                              name);
    }
    scope.AddVarSymbol(SymbolKind::LocalVariable, ty, name);

    auto decl = std::make_shared<VariableDecl>();
    decl->type = ty;
    decl->token = token;
    return decl;
}

std::shared_ptr<ASTNode> Sema::SemaIfStmtNode(std::shared_ptr<ASTNode> condNode, std::shared_ptr<ASTNode> thenNode, std::shared_ptr<ASTNode> elseNode)
{
    auto ifStmt = std::make_shared<IfStmt>();
    ifStmt->condNode = condNode;
    ifStmt->thenNode = thenNode;
    ifStmt->elseNode = elseNode;

    return ifStmt;
}

// Resolves `token` and returns a VariableAccessExpr typed with the symbol's
// type. FindVarSymbol is used rather than FindVarSymbolInCurEnv; presumably
// it also searches enclosing scopes.
// An unknown name is reported as err_undefine, and the node is then left
// with a null type. Previously execution continued into symbol->getTy() on a
// null symbol.
std::shared_ptr<ASTNode> Sema::SemaVariableAccessNode(Token token)
{
    llvm::StringRef text(token.ptr, token.length);
    auto symbol = scope.FindVarSymbol(text);

    auto varAcc = std::make_shared<VariableAccessExpr>();
    varAcc->token = token;
    if (symbol)
    {
        varAcc->type = symbol->getTy();
    }
    else
    {
        GetDiaEngine().Report(llvm::SMLoc::getFromPointer(token.ptr),
                              diag::err_undefine,
                              llvm::StringRef(token.ptr, token.length));
        varAcc->type = nullptr;
    }

    return varAcc;
}

// Builds `left = right`. Only a VariableAccessExpr is a valid assignment
// target; anything else is reported as err_lvalue.
// `token` is the '=' token. It is now stored on the node so later passes can
// point at it; previously the parameter was accepted but never used.
std::shared_ptr<ASTNode> Sema::SemaAssignExprNode(std::shared_ptr<ASTNode> left, std::shared_ptr<ASTNode> right, Token token)
{
    assert(left && right);

    if (!llvm::isa<VariableAccessExpr>(left.get()))
    {
        diagEngine.Report(llvm::SMLoc::getFromPointer(left->token.ptr), diag::err_lvalue);
    }

    auto assign = std::make_shared<AssignExpr>();
    assign->left = left;
    assign->right = right;
    assign->token = token;

    return assign;
}

std::shared_ptr<ASTNode> Sema::SemaBinaryExprNode(std::shared_ptr<ASTNode> left, std::shared_ptr<ASTNode> right, OpCode op)
{
    auto binaryExpr = std::make_shared<BinaryExpr>();
    binaryExpr->left = left;
    binaryExpr->right = right;
    binaryExpr->op = op;

    return binaryExpr;
}

// Builds a NumberExpr for the literal `token` with type `ty`.
std::shared_ptr<ASTNode> Sema::SemaNumberExprNode(Token token, CType *ty)
{
    auto literal = std::make_shared<NumberExpr>();
    literal->type = ty;
    literal->token = token;
    return literal;
}

void Sema::EnterScope()
{
    scope.EnterScope();
}
void Sema::ExitScope()
{
    scope.ExitScope();
}

2.5 代码生成 (CodeGen)

实现代码

codegen.h

#pragma once

#include "ast.h"
#include "parser.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/ADT/StringMap.h"

/// Generates LLVM IR for a parsed Program by walking the AST with the Visitor
/// pattern. Construction runs the whole pipeline: VisitProgram() emits every
/// statement into a single `main` function and prints the module to stdout.
class CodeGen : public Visitor
{
public:
    /// Generates (and prints) the IR for `program` immediately.
    /// `explicit` prevents accidental implicit conversion from a Program pointer.
    explicit CodeGen(std::shared_ptr<Program> program)
    {
        module = std::make_shared<llvm::Module>("expr", context);
        VisitProgram(program.get());
    }

public:
    llvm::Value *VisitProgram(Program *p) override;
    llvm::Value *VisitDeclStmt(DeclStmt *decl) override;
    llvm::Value *VisitBlockStmt(BlockStmt* block) override;
    llvm::Value *VisitIfStmt(IfStmt *ifStmt) override;
    llvm::Value *VisitVariableDecl(VariableDecl *decl) override;
    llvm::Value *VisitVariableAccessExpr(VariableAccessExpr *varaccExpr) override;
    llvm::Value *VisitAssignExpr(AssignExpr *assignExpr) override;
    llvm::Value *VisitBinaryExpr(BinaryExpr *binaryExpr) override;
    llvm::Value *VisitNumberExpr(NumberExpr *factorExpr) override;

private:
    // Declaration order matters: `context` must be constructed before
    // `irBuilder` (which is initialized from it) and destroyed after `module`
    // (members are destroyed in reverse declaration order).
    llvm::LLVMContext context;
    llvm::IRBuilder<> irBuilder{context};
    std::shared_ptr<llvm::Module> module;
    // Function currently being emitted; new basic blocks are appended to it.
    llvm::Function *curFunc{nullptr};
    // Variable name -> (address returned by alloca, allocated LLVM type).
    llvm::StringMap<std::pair<llvm::Value *, llvm::Type *>> varAddrTyMap;
};

codegen.cc

#include "codegen.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"

/*
; ModuleID = 'expr'
source_filename = "expr"

@0 = private unnamed_addr constant [16 x i8] c"expr value: %d\0A\00", align 1

declare i32 @printf(ptr, ...)

define i32 @main() {
entry:
  %a = alloca i32, align 4
  store i32 0, ptr %a, align 4
  %a1 = load i32, ptr %a, align 4
  %b = alloca i32, align 4
  store i32 5, ptr %b, align 4
  %b2 = load i32, ptr %b, align 4
  br label %cond

cond:                                             ; preds = %entry
  %a3 = load i32, ptr %a, align 4
  %0 = icmp ne i32 %a3, 0
  br i1 %0, label %then, label %else

then:                                             ; preds = %cond
  store i32 3, ptr %b, align 4
  %b4 = load i32, ptr %b, align 4
  br label %last

else:                                             ; preds = %cond
  store i32 4, ptr %a, align 4
  %a5 = load i32, ptr %a, align 4
  br label %last

last:                                             ; preds = %else, %then
  %a6 = load i32, ptr %a, align 4
  %b7 = load i32, ptr %b, align 4
  %mul = mul nsw i32 %a6, %b7
  %sub = sub nsw i32 %mul, 4
  %1 = call i32 (ptr, ...) @printf(ptr @0, i32 %sub)
  ret i32 0
}

*/

/// Entry point of code generation.
///
/// Declares `printf`, creates `i32 main()` with an `entry` block, emits every
/// top-level statement into it, prints the value of the last statement via
/// printf, returns 0 from main, verifies main and prints the module to stdout.
///
/// @param p the parsed program.
/// @return the `ret` instruction that terminates main.
llvm::Value *CodeGen::VisitProgram(Program *p)
{
    // Declare the variadic `printf` so results can be printed at run time.
    // NOTE(review): getInt8PtrTy() is deprecated in recent LLVM releases;
    // switch to getPtrTy() once the minimum supported LLVM version allows it.
    auto printFunctionType = llvm::FunctionType::get(irBuilder.getInt32Ty(), {irBuilder.getInt8PtrTy()}, true);
    auto printFunction = llvm::Function::Create(printFunctionType, llvm::GlobalValue::ExternalLinkage, "printf", module.get());
    // Every statement is emitted into `i32 main()`.
    auto mainFunctionType = llvm::FunctionType::get(irBuilder.getInt32Ty(), false);
    auto mainFunction = llvm::Function::Create(mainFunctionType, llvm::GlobalValue::ExternalLinkage, "main", module.get());
    llvm::BasicBlock *entryBasicBlock = llvm::BasicBlock::Create(context, "entry", mainFunction);
    irBuilder.SetInsertPoint(entryBasicBlock);
    // Visitors that create basic blocks (e.g. VisitIfStmt) append them to curFunc.
    curFunc = mainFunction;

    llvm::Value *lastVal = nullptr;
    for (auto node : p->nodeVec)
    {
        lastVal = node->Accept(this);
    }

    // Only an i32 matches "%d". VisitVariableDecl returns the variable's
    // alloca (a pointer), so a trailing declaration can leave a pointer in
    // lastVal; that must not be handed to printf as an int.
    if (lastVal && lastVal->getType()->isIntegerTy(32))
        irBuilder.CreateCall(printFunction, {irBuilder.CreateGlobalStringPtr("expr value: %d\n"), lastVal});
    else
        irBuilder.CreateCall(printFunction, {irBuilder.CreateGlobalStringPtr("last instruction is not expr!\n")});

    llvm::Value *ret = irBuilder.CreateRet(irBuilder.getInt32(0));

    // verifyFunction() returns true when the IR is malformed; report the
    // reason on stderr instead of silently discarding the result.
    if (llvm::verifyFunction(*mainFunction, &llvm::errs()))
        llvm::errs() << "error: generated IR for main failed verification\n";

    module->print(llvm::outs(), nullptr);
    return ret;
}

/// Emits every child node of a declaration statement, in order.
/// NOTE(review): the children are presumably the VariableDecl nodes plus any
/// initializer assignments built by the parser — confirm against the parser.
///
/// @return the value produced by the last child, or nullptr if there is none.
llvm::Value *CodeGen::VisitDeclStmt(DeclStmt *declStmt)
{
    llvm::Value *result = nullptr;
    for (const auto &child : declStmt->nodeVec)
        result = child->Accept(this);
    return result;
}

/// Emits the statements of a `{ ... }` block in order.
///
/// varAddrTyMap is keyed only by variable name, so it is snapshotted on entry
/// and restored on exit. A declaration inside the block that shadows an outer
/// variable (VisitVariableDecl overwrites the entry) then only affects code
/// inside the block, and the outer binding is visible again afterwards.
///
/// @return the value of the last statement, or nullptr for an empty block.
llvm::Value *CodeGen::VisitBlockStmt(BlockStmt *block)
{
    auto outerVars = varAddrTyMap;

    llvm::Value *lastVal = nullptr;
    for (auto node : block->nodeVec)
    {
        lastVal = node->Accept(this);
    }

    varAddrTyMap = std::move(outerVars);
    return lastVal;
}

/// Lowers `if (cond) then [else otherwise]` to the block layout
///   cond -> then -> last
///        -> else -> last      (or cond -> last when there is no else)
/// The condition is true when its i32 value is non-zero.
/// Leaves the builder positioned at the start of `last`.
///
/// @return nullptr: an if statement produces no value.
llvm::Value *CodeGen::VisitIfStmt(IfStmt *ifStmt)
{
    // Blocks are created up front and in this order, which fixes their order
    // in the printed IR.
    llvm::BasicBlock *condBasicBlock = llvm::BasicBlock::Create(context, "cond", curFunc);
    llvm::BasicBlock *thenBasicBlock = llvm::BasicBlock::Create(context, "then", curFunc);
    llvm::BasicBlock *elseBasicBlock = nullptr;
    if (ifStmt->elseNode)
        elseBasicBlock = llvm::BasicBlock::Create(context, "else", curFunc);
    llvm::BasicBlock *lastBasicBlock = llvm::BasicBlock::Create(context, "last", curFunc);

    // The current block must be terminated explicitly; IRBuilder never adds a
    // fall-through branch on its own.
    irBuilder.CreateBr(condBasicBlock);
    irBuilder.SetInsertPoint(condBasicBlock);
    llvm::Value *ret = ifStmt->condNode->Accept(this);
    // C semantics: any non-zero integer is true.
    llvm::Value *condVal = irBuilder.CreateICmpNE(ret, irBuilder.getInt32(0));

    // Without an else branch, a false condition jumps straight to `last`.
    irBuilder.CreateCondBr(condVal, thenBasicBlock, elseBasicBlock ? elseBasicBlock : lastBasicBlock);

    // Emits one arm into `bb` and closes it with a jump to `last`.
    auto emitArm = [&](llvm::BasicBlock *bb, const std::shared_ptr<ASTNode> &stmt) {
        irBuilder.SetInsertPoint(bb);
        // Guard against a null statement node (e.g. if the parser represents an
        // empty `;` statement as nullptr — unverified); the arm is then empty.
        if (stmt)
            stmt->Accept(this);
        irBuilder.CreateBr(lastBasicBlock);
    };

    emitArm(thenBasicBlock, ifStmt->thenNode);
    if (elseBasicBlock)
        emitArm(elseBasicBlock, ifStmt->elseNode);

    irBuilder.SetInsertPoint(lastBasicBlock);
    return nullptr;
}

/// Allocates stack storage for a declared variable and records its address
/// and LLVM type under the variable's name.
///
/// The map entry is overwritten rather than inserted, so a declaration that
/// shadows an outer variable binds the name to its own alloca;
/// VisitBlockStmt restores the outer binding when the block ends.
///
/// @return the address produced by the alloca.
llvm::Value *CodeGen::VisitVariableDecl(VariableDecl *decl)
{
    llvm::Type *ty = nullptr;
    if (decl->type == CType::getIntTy())
    {
        ty = irBuilder.getInt32Ty();
    }
    llvm::StringRef text(decl->token.ptr, decl->token.length);
    llvm::Value *varAddr = irBuilder.CreateAlloca(ty, nullptr, text);
    varAddrTyMap[text] = std::make_pair(varAddr, ty);
    return varAddr;
}

/// Loads the current value of a variable (an rvalue use).
///
/// The variable's address and type come from varAddrTyMap, which
/// VisitVariableDecl fills. lookup() is used instead of operator[] so a
/// missing name does not leave a {nullptr, nullptr} entry in the map.
///
/// @return the loaded value, named after the variable.
llvm::Value *CodeGen::VisitVariableAccessExpr(VariableAccessExpr *varaccExpr)
{
    llvm::StringRef text(varaccExpr->token.ptr, varaccExpr->token.length);
    std::pair<llvm::Value *, llvm::Type *> entry = varAddrTyMap.lookup(text);
    llvm::Value *varAddr = entry.first;
    llvm::Type *ty = entry.second;
    if (varaccExpr->type == CType::getIntTy())
    {
        ty = irBuilder.getInt32Ty();
    }
    return irBuilder.CreateLoad(ty, varAddr, text);
}

/// Lowers `var = expr`: evaluates the right-hand side, stores it into the
/// variable's alloca, then reloads it so the assignment itself yields an
/// rvalue (e.g. `a = 3` evaluates to 3).
///
/// @return the value loaded back from the variable after the store.
llvm::Value *CodeGen::VisitAssignExpr(AssignExpr *assignExpr)
{
    // Sema reports err_lvalue for any other left-hand side. llvm::cast checks
    // the node kind in assert-enabled builds, unlike the unchecked C-style cast.
    auto *varAccExpr = llvm::cast<VariableAccessExpr>(assignExpr->left.get());
    llvm::StringRef text(varAccExpr->token.ptr, varAccExpr->token.length);
    // lookup() does not insert an empty entry on a miss (operator[] would).
    std::pair<llvm::Value *, llvm::Type *> entry = varAddrTyMap.lookup(text);
    llvm::Value *addr = entry.first;
    llvm::Type *ty = entry.second;
    llvm::Value *rValue = assignExpr->right->Accept(this);
    irBuilder.CreateStore(rValue, addr);
    return irBuilder.CreateLoad(ty, addr, text);
}

/// Lowers a binary arithmetic expression on i32 operands.
///
/// The left operand is emitted before the right one. add/sub/mul carry the
/// `nsw` (no signed wrap) flag: signed overflow yields poison, mirroring C's
/// undefined signed overflow. It does not prevent overflow. Division is signed.
///
/// @return the resulting value, or nullptr for an unhandled operator.
llvm::Value *CodeGen::VisitBinaryExpr(BinaryExpr *binaryExpr)
{
    llvm::Value *lhs = binaryExpr->left->Accept(this);
    llvm::Value *rhs = binaryExpr->right->Accept(this);

    switch (binaryExpr->op)
    {
    case OpCode::add:
        return irBuilder.CreateNSWAdd(lhs, rhs, "add");
    case OpCode::sub:
        return irBuilder.CreateNSWSub(lhs, rhs, "sub");
    case OpCode::mul:
        return irBuilder.CreateNSWMul(lhs, rhs, "mul");
    case OpCode::div:
        return irBuilder.CreateSDiv(lhs, rhs, "div");
    default:
        return nullptr;
    }
}

/// Lowers an integer literal to an i32 constant.
///
/// @return the constant holding the token's numeric value.
llvm::Value *CodeGen::VisitNumberExpr(NumberExpr *factorExpr)
{
    const int literal = factorExpr->token.value;
    return irBuilder.getInt32(literal);
}

2.6 测试块语句功能

生成 IR 文件

bin/expr test/expr.txt > test/expr.ll

生成的 IR 内容

; ModuleID = 'expr'
source_filename = "expr"

@0 = private unnamed_addr constant [16 x i8] c"expr value: %d\0A\00", align 1

declare i32 @printf(ptr, ...)

define i32 @main() {
entry:
  %a = alloca i32, align 4
  store i32 3, ptr %a, align 4
  %a1 = load i32, ptr %a, align 4
  %b = alloca i32, align 4
  store i32 5, ptr %b, align 4
  %b2 = load i32, ptr %b, align 4
  br label %cond

cond:                                             ; preds = %entry
  %a3 = load i32, ptr %a, align 4
  %0 = icmp ne i32 %a3, 0
  br i1 %0, label %then, label %else

then:                                             ; preds = %cond
  store i32 2, ptr %a, align 4
  %a4 = load i32, ptr %a, align 4
  store i32 3, ptr %b, align 4
  %b5 = load i32, ptr %b, align 4
  br label %last

else:                                             ; preds = %cond
  store i32 4, ptr %a, align 4
  %a6 = load i32, ptr %a, align 4
  store i32 2, ptr %b, align 4
  %b7 = load i32, ptr %b, align 4
  br label %last

last:                                             ; preds = %else, %then
  %a8 = load i32, ptr %a, align 4
  %b9 = load i32, ptr %b, align 4
  %mul = mul nsw i32 %a8, %b9
  %sub = sub nsw i32 %mul, 4
  %1 = call i32 (ptr, ...) @printf(ptr @0, i32 %sub)
  ret i32 0
}

运行 IR

lli test/expr.ll

运行结果

expr value: 2

3. 总结

通过本文的实现,我们成功为编译器添加了if语句和块语句的支持:

实现的功能特性

  1. 基础if语句:支持带else和不带else的两种形式
  2. 块语句:使用花括号{}组织的语句块
  3. 作用域管理:块语句自带作用域,支持变量隔离
  4. 控制流:正确的基本块划分和跳转逻辑

技术要点

  • 词法分析:新增 if、else 关键字以及 {、} 符号等 token 的识别
  • 语法分析:完善语句解析,支持块语句嵌套
  • 语义分析:实现作用域的进入和退出管理
  • 代码生成:正确生成基本块和控制流指令

测试验证

通过多个测试用例验证了实现的正确性:

  • 条件为真时的执行路径
  • 条件为假时的执行路径
  • 块语句内的变量作用域隔离
  • 复杂的嵌套控制结构