ch07 : 计算图中表达式语法解析

110 阅读2分钟

项目地址
课程地址
个人完成的作业地址
作者原文

ch06中我们从pnnx计算图转化构建了自己的计算图。尽管我们可以获得每个op的类型和名字,但有一类op尽管我们拿到了但上节没法处理,那就是表达式。
此节我们对表达式类型的算子进行词法解析和语法解析。
如图所示:expr: add(@0,@1) image.png

表达式介绍

pnnx生成的抽象表达式为

add(@0,@1)
add(@0,mul(@1,@2))
add(add(mul(@0,@1),mul(@2,add(add(add(@0,@2),@3),@4))),@5)

即每个运算符跟两个运算数字,其中运算数可以为@数字或者一个完整的表达式mul(@1,@2)

词法解析

Token类型

enum class TokenType{
    TokenUnknown = -1,
    TokenInputNumber = 0,
    TokenComma = 1,
    TokenAdd = 2,
    TokenMul = 3,
    TokenLeftBracket = 4,
    TokenRightBracket = 5,
    TokenDiv = 6,
};

词与节点

// 词的类型与位置
struct Token{
    TokenType token_type = TokenType::TokenUnknown;
    int32_t start_pos = 0; // 词语出现位置
    int32_t end_pos = 0; // 词语结束的位置
    Token(TokenType token_type, int32_t start_pos, int32_t end_pos):token_type(token_type), start_pos(start_pos), end_pos(end_pos) {}
};
// 节点的值/类型(num_index为正值, 该节点为数字类型, 值为num_index; 
// 小于0,则为运算符, 具体类型为-num_index) 
// 叶子节点
struct TokenNode{
    int32_t num_index = -1;
    std::shared_ptr<TokenNode> left = nullptr;
    std::shared_ptr<TokenNode> right = nullptr;
    TokenNode(int32_t num_index, std::shared_ptr<TokenNode> left, std::shared_ptr<TokenNode> right);
    TokenNode() = default;
};

词法解析,一个存词,一个存字符串

void Tokenizer(bool need_retoken=false);
std::vector<Token> tokens_;
std::vector<std::string> token_strs_;

具体解析流程

// 首先是判断statement_是否为空, 随后删除表达式中的所有空格和制表符.
CHECK(!statement_.empty()) << "The input statement is empty!";
statement_.erase(std::remove_if(statement_.begin(), statement_.end(), [](char c) {
    return std::isspace(c);
}), statement_.end());   //?
CHECK(!statement_.empty()) << "The input statement is empty!";

// 解析add、div、mul的思路是一样的
if(c=='a'){
    CHECK(i+1<statement_.size()&&statement_.at(i+1)=='d')<< "Parse add token failed, illegal character: " << c;
    CHECK(i+2<statement_.size()&&statement_.at(i+1)=='d')<< "Parse add token failed, illegal character: " << c;
    Token token(TokenType::TokenAdd, i, i+3);
    tokens_.push_back(token);
    std::string token_operation = std::string(statement_.begin() + i, statement_.begin() + i + 3);
    token_strs_.push_back(token_operation);
    i = i + 3;
}
// 解析数字类型
else if(c=='@'){
    CHECK(i+1<statement_.size()&&std::isdigit(statement_.at(i+1)))<< "Parse number token failed, illegal character: " << c;
    int32_t j = i+1;
    for(;j<statement_.size();++j){
        if(!std::isdigit(statement_.at(j))){
            break;
        }
    }
    Token token(TokenType::TokenInputNumber, i, j);
    CHECK(token.start_pos < token.end_pos);
    tokens_.push_back(token);
    std::string token_input_number = std::string(statement_.begin() + i, statement_.begin() + j);
    token_strs_.push_back(token_input_number);
    i = j;
}
// 解析 , ( )
else if (c == ',') {
    Token token(TokenType::TokenComma, i, i + 1);
    tokens_.push_back(token);
    std::string token_comma = std::string(statement_.begin() + i, statement_.begin() + i + 1);
    token_strs_.push_back(token_comma);
    i += 1;
}

语法解析

递归方法
Generate_函数

// Generate_遇到的第一个token不是数字就是运算符,不能是标点符号
CHECK(current_token.token_type == TokenType::TokenInputNumber|| current_token.token_type == TokenType::TokenAdd 
|| current_token.token_type == TokenType::TokenMul|| current_token.token_type == TokenType::TokenDiv);

if (current_token.token_type == TokenType::TokenInputNumber) {
    // 如果当前token类型是输入数字类型, 则直接返回一个操作数token作为一个叶子节点,不再向下递归
    uint32_t start_pos = current_token.start_pos + 1;
    uint32_t end_pos = current_token.end_pos;
    CHECK(end_pos > start_pos);
    CHECK(end_pos <= this->statement_.length());
    const std::string &str_number =
        std::string(this->statement_.begin() + start_pos, this->statement_.begin() + end_pos);
    return std::make_shared<TokenNode>(std::stoi(str_number), nullptr, nullptr);
}

// 如果遇到的第一个token是运算符
else if (current_token.token_type == TokenType::TokenMul || current_token.token_type == TokenType::TokenAdd|| current_token.token_type == TokenType::TokenDiv) {
std::shared_ptr<TokenNode> current_node = std::make_shared<TokenNode>();
current_node->num_index = -int(current_token.token_type);//运算符类型

// 判断运算符之后是否有( left bracket
CHECK(this->tokens_.at(index).token_type == TokenType::TokenLeftBracket);

// 判断当前需要处理的left token是不是合法类型
// Generate_遇到的第一个token不是数字就是运算符,不能是标点符号
if (left_token.token_type == TokenType::TokenInputNumber
    || left_token.token_type == TokenType::TokenAdd || left_token.token_type == TokenType::TokenMul|| left_token.token_type == TokenType::TokenDiv) {
// (之后进行向下递归得到@0
current_node->left = Generate_(index);

// 当前的index指向add(@1,@2)中的逗号
index += 1;
CHECK(index < this->tokens_.size());
CHECK(this->tokens_.at(index).token_type == TokenType::TokenComma);

// 构建右子树
if (right_token.token_type == TokenType::TokenInputNumber
    || right_token.token_type == TokenType::TokenAdd || right_token.token_type == TokenType::TokenMul|| right_token.token_type == TokenType::TokenDiv) {
current_node->right = Generate_(index);

// 右括号
CHECK(this->tokens_.at(index).token_type == TokenType::TokenRightBracket);

总结:
表达式解析函数逻辑如下:

  1. 遇到数字,直接返回
  2. 遇到运算符,继续
    2.1判断是不是左括号
    2.2解析表达式
    2.3判断是不是逗号
    2.4解析表达式
    2.5判断是不是右括号,返回

TEST

中序遍历打印

static void ShowNodes(const std::shared_ptr<kuiper_infer::TokenNode> &node) {
  if (!node) {
    return;
  }
  // 中序遍历的顺序
  ShowNodes(node->left);

  if (node->num_index < 0) {
    if (node->num_index == -int(kuiper_infer::TokenType::TokenAdd)) {
      LOG(INFO) << "ADD";
    } else if (node->num_index == -int(kuiper_infer::TokenType::TokenMul)) {
      LOG(INFO) << "MUL";
    }
else if (node->num_index == -int(kuiper_infer::TokenType::TokenDiv)) {
      LOG(INFO) << "Div";
    }

  } else {
    LOG(INFO) << "NUM: " << node->num_index;
  }
  
  ShowNodes(node->right);
}
add(@1,@2)打印顺序1 add 2
add(mul(@0,@1),@2)打印顺序0 mul 1 add 2
add(mul(@0,@1),mul(@2,add(@3,@4)))打印顺序0 mul 1 add 2 mul 3 add 4
add(div(@0,@1),@2)打印顺序0 div 1 add 2

打印结果 7.1.png 目前总共通过16个test 7.2.png