C++: Visitor Pattern Dynamic vs. Static

238 reads · 4 min read

Let us use an example to drive our thinking about the Visitor Pattern.

Consider the following context-free grammar (non-terminals, alphabet, production rules, start symbol) for simple arithmetic with addition, multiplication, and parentheses.

Non-terminals = {
Expr, Term, Factor, Nat
}

Alphabet = {
Natural Numbers (0, 1, 2, 3 ...),
+, *, (, )
}

Production Rules = {
Expr -> Term | Expr + Term,
Term -> Factor | Term * Factor,
Factor -> Nat | ( Expr ),
Nat -> 0 | 1 | 2 | 3 | ...
}


Start Symbol = Expr

We can parse 1 + 2 * (3 + 4) into an abstract syntax tree, written in prefix notation as:

+ (1) (* (2) (+ (3) (4)))

It is a tree-like structure, so we can represent the nodes of this expression tree as C++ classes (we use struct rather than class to keep the code concise, since the members of a struct are public by default):

// Base class for every node of the expression tree.
struct Node {
  virtual ~Node() = default;
};
// left + right. The children are non-owning pointers.
struct Add: public Node {
  Node* left;
  Node* right;
  Add(Node* l, Node* r): left(l), right(r) {}
};
// left * right. The children are non-owning pointers.
struct Mul: public Node {
  Node* left;
  Node* right;
  Mul(Node* l, Node* r): left(l), right(r) {}
};
// A natural-number literal.
struct Num: public Node {
  int value;
  explicit Num(int v): value(v) {}
};

// 1 + 2 * (3 + 4), built bottom-up. Every node is a named object and the inner
// nodes only point at their children, so nothing has to be freed.
Num n1(1), n2(2), n3(3), n4(4);
Add sum34(&n3, &n4);
Mul product(&n2, &sum34);
Add expr(&n1, &product);

To keep this example simple and focused on the Visitor Pattern, we skip lexing and parsing and construct the expression tree directly.

Now we want the value of this expression, so we add a getVal method to each kind of node.

struct Node {
  virtual int getVal() = 0;
};
struct Add: public Node {
  ...
  int getVal() override {return left->getVal() + right->getVal();}
};
struct Mul: public Node {
  ...
  int getValue() override { return left->getVal() * right->getVal();}
};
struct Num: public Node {
  ...
  int getVal() override {return value;}
};

expr.getVal();

So far so good: we can compute the value of an expression from the values of its sub-expressions.

What if we want to type-check this expression? Even though we have only one type, int, let us pretend there are several types and require the operands of an addition or a multiplication to have the same type. Then we add a getType method to each node:

enum class Type {
  Integer, Char, Float, Double, Unknown
};

struct Node {
...
  virtual Type getType() = 0;
};

struct Add: public Node {
  ...
  Type getType() override {
    return left->getType()==right->getType() ? left->getType() : Unknown;}
};
struct Mul: public Node {
  ...
  Type getType() override {
   return left->getType()==right->getType() ? left->getType() : Unknown;}
};
struct Num: public Node {
  ...
  Type getType() override {return Integer;}
};

expr.getType();

What if we want to pretty-print the expression? We can add yet another method to each node.

But what if we are not the author of these classes and cannot modify them? How can we add a new operation over each component of the object structure, i.e. each node of the expression? This is where the Visitor Pattern comes to the rescue.

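We encapsulate each new operation in a visitor, and each node (possibly of a different class) accepts that visitor and lets it operate on the node. The node classes still need a one-time change, a single accept method, but after that, new operations no longer require modifying them.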

struct Add; struct Mul; struct Num;

// One operation over the expression tree, with one visit() overload per
// concrete node type. visit() returns void, so results live in the visitor.
struct Visitor {
  virtual ~Visitor() = default;
  virtual void visit(Add*) = 0;
  virtual void visit(Mul*) = 0;
  virtual void visit(Num*) = 0;
};

// The only hook the nodes need. The virtual accept() selects the node's
// dynamic type (first dispatch). Inside it, `this` has the concrete static
// type, so overload resolution picks the matching visit(), and that virtual
// call selects the visitor's dynamic type (second dispatch).
struct Node {
  virtual ~Node() = default;
  virtual void accept(Visitor*) = 0;
};
struct Add: public Node {
  Node* left;
  Node* right;
  Add(Node* l, Node* r): left(l), right(r) {}
  void accept(Visitor* v) override { v->visit(this); }
};
struct Mul: public Node {
  Node* left;
  Node* right;
  Mul(Node* l, Node* r): left(l), right(r) {}
  void accept(Visitor* v) override { v->visit(this); }
};
struct Num: public Node {
  int value;
  explicit Num(int v): value(v) {}
  void accept(Visitor* v) override { v->visit(this); }
};

// Evaluates the tree; after a visit, cur holds the value of the visited node.
struct GetValue: public Visitor {
  int cur = 0;
  void visit(Add* a) override {
    a->left->accept(this);  int left = cur;
    a->right->accept(this); int right = cur;
    cur = left + right;
  }
  void visit(Mul* m) override {
    m->left->accept(this);  int left = cur;
    m->right->accept(this); int right = cur;
    cur = left * right;
  }
  void visit(Num* n) override { cur = n->value; }
};

enum class Type {
  Integer, Char, Float, Double, Unknown
};
// Type-checks the tree: the operands of + and * must have the same type.
// After a visit, cur holds the type of the visited node.
struct TypeCheck: public Visitor {
  Type cur = Type::Unknown;
  void visit(Add* a) override {
    a->left->accept(this);  Type left = cur;
    a->right->accept(this); Type right = cur;
    cur = left == right ? left : Type::Unknown;
  }
  void visit(Mul* m) override {
    m->left->accept(this);  Type left = cur;
    m->right->accept(this); Type right = cur;
    cur = left == right ? left : Type::Unknown;
  }
  // A literal is always an Integer; reporting Unknown here would make every
  // expression fail the check.
  void visit(Num*) override { cur = Type::Integer; }
};

// Builds 1 + 2 * (3 + 4) and runs both visitors over it.
void demoDynamicVisitor() {
  Num n1(1), n2(2), n3(3), n4(4);
  Add sum34(&n3, &n4);
  Mul product(&n2, &sum34);
  Add expr(&n1, &product);

  GetValue getValue;
  expr.accept(&getValue);   // getValue.cur == 15
  TypeCheck typeCheck;
  expr.accept(&typeCheck);  // typeCheck.cur == Type::Integer
}

With the Visitor Pattern, the node classes stay untouched once they provide accept: whenever we want a new operation, we write a new visitor type that encapsulates it and traverses the structure. Moreover, if every visitor uses the same kind of traversal, we can move the traversal logic into the structure itself (or into a shared base visitor) to avoid duplicating it across visitors.

Note also that in this implementation, visit returns void, so the result of each visit is stored in the visitor's state (its cur member).

That is kind of cool, right? But this implementation relies on double dispatch. The virtual accept call selects the code for the node's dynamic type, and the virtual visit call inside it selects the code for the visitor's dynamic type. In LLVM, I encountered another way to implement the Visitor Pattern statically, without virtual functions. Let us dive in and reimplement this example the "LLVM visitor" way.

// Base of every node. Instead of virtual functions, each node carries a kind
// tag, set by the derived constructor, that identifies its concrete type.
struct Node {
  enum NodeKind { NK_Add, NK_Mul, NK_Num };
  NodeKind getKind() const { return kind; }
protected:
  // Only derived nodes may construct a Node, so the tag always matches the
  // real type.
  explicit Node(NodeKind k): kind(k) {}
private:
  NodeKind kind;
};

struct Add: public Node {
  Node* left;
  Node* right;
  Add(Node* l, Node* r): Node(NK_Add), left(l), right(r) {}
};
struct Mul: public Node {
  Node* left;
  Node* right;
  Mul(Node* l, Node* r): Node(NK_Mul), left(l), right(r) {}
};
struct Num: public Node {
  int value;
  explicit Num(int v): Node(NK_Num), value(v) {}
};

// CRTP base visitor. SubClass is the derived visitor itself. visit(Node*)
// reads the kind tag, static_casts the node to its concrete type, and calls
// SubClass's overload for that type. That call is resolved at compile time.
// A derived visitor that overloads visit() must add
// `using NodeVisitor<Derived>::visit;`. Otherwise its overloads hide
// visit(Node*) and the defaults below.
template<class SubClass>
struct NodeVisitor {
  void visit(Node* n) {
    static_assert(std::is_base_of<NodeVisitor, SubClass>::value,
      "Must pass the derived type to this template!");
    SubClass* self = static_cast<SubClass*>(this);
    switch (n->getKind()) {
    case Node::NK_Add: self->visit(static_cast<Add*>(n)); break;
    case Node::NK_Mul: self->visit(static_cast<Mul*>(n)); break;
    case Node::NK_Num: self->visit(static_cast<Num*>(n)); break;
    default: assert(false && "unreachable");
    }
  }

  // Default handlers: a derived visitor only overrides the ones it needs.
  void visit(Add*) {}
  void visit(Mul*) {}
  void visit(Num*) {}
};

// Evaluates the tree; after a visit, cur holds the value of the visited node.
struct GetValue: public NodeVisitor<GetValue> {
  // Re-expose the base dispatcher visit(Node*) hidden by the overloads below.
  using NodeVisitor<GetValue>::visit;
  int cur = 0;
  void visit(Add* a) {
    visit(a->left);  int left = cur;
    visit(a->right); int right = cur;
    cur = left + right;
  }
  void visit(Mul* m) {
    visit(m->left);  int left = cur;
    visit(m->right); int right = cur;
    cur = left * right;
  }
  void visit(Num* n) { cur = n->value; }
};

enum class Type {
  Integer, Char, Float, Double, Unknown
};
// Type-checks the tree: the operands of + and * must have the same type.
// After a visit, cur holds the type of the visited node.
struct TypeCheck: public NodeVisitor<TypeCheck> {
  // Re-expose the base dispatcher visit(Node*) hidden by the overloads below.
  using NodeVisitor<TypeCheck>::visit;
  Type cur = Type::Unknown;
  void visit(Add* a) {
    visit(a->left);  Type left = cur;
    visit(a->right); Type right = cur;
    cur = left == right ? left : Type::Unknown;
  }
  void visit(Mul* m) {
    visit(m->left);  Type left = cur;
    visit(m->right); Type right = cur;
    cur = left == right ? left : Type::Unknown;
  }
  void visit(Num*) { cur = Type::Integer; }
};

// Builds 1 + 2 * (3 + 4) and runs both visitors through a Node*, so the
// kind-tag dispatcher in NodeVisitor::visit(Node*) is exercised.
void demoStaticVisitor() {
  Num n1(1), n2(2), n3(3), n4(4);
  Add sum34(&n3, &n4);
  Mul product(&n2, &sum34);
  Add expr(&n1, &product);
  Node* root = &expr;

  GetValue getValue;
  getValue.visit(root);   // getValue.cur == 15
  TypeCheck typeCheck;
  typeCheck.visit(root);  // typeCheck.cur == Type::Integer
}

As you can see, there are no virtual functions in the code above. The dispatch logic lives in the visitor template. It uses CRTP (the Curiously Recurring Template Pattern) to call the derived visitor's matching method through function overloading, not overriding. Virtual functions pick the function to call from the object's dynamic type at run time. Here, instead, each node stores a kind tag identifying its concrete type. The visitor switches on that tag, static_casts the node to that type, and calls the matching overload, which the compiler resolves statically.

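That is, roughly, a static Visitor Pattern. The set of node kinds is closed and known at compile time, and each dispatch costs a switch on the kind tag instead of two virtual calls (accept and visit).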