ch06中我们从pnnx计算图转化构建了自己的计算图。尽管我们可以获得每个op的类型和名字,但有一类op尽管我们拿到了但上节没法处理,那就是表达式。
此节我们对表达式类型的算子进行词法解析和语法解析。
如图所示:expr: add(@0,@1)
表达式介绍
pnnx生成的抽象表达式为
add(@0,@1)
add(@0,mul(@1,@2))
add(add(mul(@0,@1),mul(@2,add(add(add(@0,@2),@3),@4))),@5)
即每个运算符跟两个运算数字,其中运算数可以为@数字或者一个完整的表达式如mul(@1,@2)。
词法解析
Token类型
enum class TokenType{
TokenUnknown = -1,
TokenInputNumber = 0,
TokenComma = 1,
TokenAdd = 2,
TokenMul = 3,
TokenLeftBracket = 4,
TokenRightBracket = 5,
TokenDiv = 6,
};
词与节点
// 词的类型与位置
struct Token{
TokenType token_type = TokenType::TokenUnknown;
int32_t start_pos = 0; // 词语出现位置
int32_t end_pos = 0; // 词语结束的位置
Token(TokenType token_type, int32_t start_pos, int32_t end_pos):token_type(token_type), start_pos(start_pos), end_pos(end_pos) {}
};
// 节点的值/类型(num_index为正值, 该节点为数字类型, 值为num_index;
// 小于0,则为运算符, 具体类型为-num_index)
// 叶子节点
struct TokenNode{
int32_t num_index = -1;
std::shared_ptr<TokenNode> left = nullptr;
std::shared_ptr<TokenNode> right = nullptr;
TokenNode(int32_t num_index, std::shared_ptr<TokenNode> left, std::shared_ptr<TokenNode> right);
TokenNode() = default;
};
词法解析,一个存词,一个存字符串
void Tokenizer(bool need_retoken=false);
std::vector<Token> tokens_;
std::vector<std::string> token_strs_;
具体解析流程
// 首先是判断statement_是否为空, 随后删除表达式中的所有空格和制表符.
CHECK(!statement_.empty()) << "The input statement is empty!";
statement_.erase(std::remove_if(statement_.begin(), statement_.end(), [](char c) {
return std::isspace(c);
}), statement_.end()); //?
CHECK(!statement_.empty()) << "The input statement is empty!";
// 解析add、div、mul的思路是一样的
if(c=='a'){
CHECK(i+1<statement_.size()&&statement_.at(i+1)=='d')<< "Parse add token failed, illegal character: " << c;
CHECK(i+2<statement_.size()&&statement_.at(i+1)=='d')<< "Parse add token failed, illegal character: " << c;
Token token(TokenType::TokenAdd, i, i+3);
tokens_.push_back(token);
std::string token_operation = std::string(statement_.begin() + i, statement_.begin() + i + 3);
token_strs_.push_back(token_operation);
i = i + 3;
}
// 解析数字类型
else if(c=='@'){
CHECK(i+1<statement_.size()&&std::isdigit(statement_.at(i+1)))<< "Parse number token failed, illegal character: " << c;
int32_t j = i+1;
for(;j<statement_.size();++j){
if(!std::isdigit(statement_.at(j))){
break;
}
}
Token token(TokenType::TokenInputNumber, i, j);
CHECK(token.start_pos < token.end_pos);
tokens_.push_back(token);
std::string token_input_number = std::string(statement_.begin() + i, statement_.begin() + j);
token_strs_.push_back(token_input_number);
i = j;
}
// 解析 , ( )
else if (c == ',') {
Token token(TokenType::TokenComma, i, i + 1);
tokens_.push_back(token);
std::string token_comma = std::string(statement_.begin() + i, statement_.begin() + i + 1);
token_strs_.push_back(token_comma);
i += 1;
}
语法解析
递归方法
Generate_函数
// Generate_遇到的第一个token不是数字就是运算符,不能是标点符号
CHECK(current_token.token_type == TokenType::TokenInputNumber|| current_token.token_type == TokenType::TokenAdd
|| current_token.token_type == TokenType::TokenMul|| current_token.token_type == TokenType::TokenDiv);
if (current_token.token_type == TokenType::TokenInputNumber) {
// 如果当前token类型是输入数字类型, 则直接返回一个操作数token作为一个叶子节点,不再向下递归
uint32_t start_pos = current_token.start_pos + 1;
uint32_t end_pos = current_token.end_pos;
CHECK(end_pos > start_pos);
CHECK(end_pos <= this->statement_.length());
const std::string &str_number =
std::string(this->statement_.begin() + start_pos, this->statement_.begin() + end_pos);
return std::make_shared<TokenNode>(std::stoi(str_number), nullptr, nullptr);
}
// 如果遇到的第一个token是运算符
else if (current_token.token_type == TokenType::TokenMul || current_token.token_type == TokenType::TokenAdd|| current_token.token_type == TokenType::TokenDiv) {
std::shared_ptr<TokenNode> current_node = std::make_shared<TokenNode>();
current_node->num_index = -int(current_token.token_type);//运算符类型
// 判断运算符之后是否有( left bracket
CHECK(this->tokens_.at(index).token_type == TokenType::TokenLeftBracket);
// 判断当前需要处理的left token是不是合法类型
// Generate_遇到的第一个token不是数字就是运算符,不能是标点符号
if (left_token.token_type == TokenType::TokenInputNumber
|| left_token.token_type == TokenType::TokenAdd || left_token.token_type == TokenType::TokenMul|| left_token.token_type == TokenType::TokenDiv) {
// (之后进行向下递归得到@0
current_node->left = Generate_(index);
// 当前的index指向add(@1,@2)中的逗号
index += 1;
CHECK(index < this->tokens_.size());
CHECK(this->tokens_.at(index).token_type == TokenType::TokenComma);
// 构建右子树
if (right_token.token_type == TokenType::TokenInputNumber
|| right_token.token_type == TokenType::TokenAdd || right_token.token_type == TokenType::TokenMul|| right_token.token_type == TokenType::TokenDiv) {
current_node->right = Generate_(index);
// 右括号
CHECK(this->tokens_.at(index).token_type == TokenType::TokenRightBracket);
总结:
表达式解析函数逻辑如下:
- 遇到数字,直接返回
- 遇到运算符,继续
2.1判断是不是左括号
2.2解析表达式
2.3判断是不是逗号
2.4解析表达式
2.5判断是不是右括号,返回
TEST
中序遍历打印
static void ShowNodes(const std::shared_ptr<kuiper_infer::TokenNode> &node) {
if (!node) {
return;
}
// 中序遍历的顺序
ShowNodes(node->left);
if (node->num_index < 0) {
if (node->num_index == -int(kuiper_infer::TokenType::TokenAdd)) {
LOG(INFO) << "ADD";
} else if (node->num_index == -int(kuiper_infer::TokenType::TokenMul)) {
LOG(INFO) << "MUL";
}
else if (node->num_index == -int(kuiper_infer::TokenType::TokenDiv)) {
LOG(INFO) << "Div";
}
} else {
LOG(INFO) << "NUM: " << node->num_index;
}
ShowNodes(node->right);
}
add(@1,@2)打印顺序1 add 2
add(mul(@0,@1),@2)打印顺序0 mul 1 add 2
add(mul(@0,@1),mul(@2,add(@3,@4)))打印顺序0 mul 1 add 2 mul 3 add 4
add(div(@0,@1),@2)打印顺序0 div 1 add 2
打印结果
目前总共通过16个test