Chapter 15: Implementing Function Definitions and Calls


Defining functions

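Functions are a basic building block of a programming language. We declare a function in the form function sayHello(a){ print(a) return a }, where a is a formal parameter; from the point of view of the function body it is a local variable. The idea behind a function's local symbol table is to inherit the global symbol table and add the local variables on top of it, which would also let the body read global variables. The interpreter in this chapter takes a simpler route: each call gets a fresh symbol table that contains only the bound parameters, so the function body does not see global variables. Functions themselves are also kept in a symbol table: a call looks up the function's syntax node by name and executes its body.
Grammar rule for a function definition: functionDefineAst : function variableAst LBRACKET variableAst (',' variableAst)* RBRACKET LBRACE programAst RBRACE, where LBRACE is '{' and RBRACE is '}'.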
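The idea of a local symbol table that inherits the global one can be sketched as a chain of scopes: a lookup checks the local table first and then falls back to its parent. This is a minimal illustration only; ScopeDemo, Scope and its get/define methods are hypothetical names, not part of this chapter's interpreter, which passes plain Map<String, Object> tables around.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical scope chain: a local table that falls back to its parent (e.g. the global table).
class ScopeDemo {
    static class Scope {
        private final Map<String, Object> vars = new HashMap<>();
        private final Scope parent; // null for the global scope

        Scope(Scope parent) {
            this.parent = parent;
        }

        // Look the name up locally first, then walk up the parent chain.
        Object get(String name) {
            for (Scope s = this; s != null; s = s.parent) {
                if (s.vars.containsKey(name)) {
                    return s.vars.get(name);
                }
            }
            throw new RuntimeException("Undefined variable: " + name);
        }

        // Always writes to this scope, so a local shadows a global of the same name.
        void define(String name, Object value) {
            vars.put(name, value);
        }
    }

    public static void main(String[] args) {
        Scope global = new Scope(null);
        global.define("e", "hello world");
        global.define("name", "global name");

        Scope local = new Scope(global); // the scope of one function call
        local.define("name", "hello");   // bind the parameter

        System.out.println(local.get("name"));  // hello (parameter shadows the global)
        System.out.println(local.get("e"));     // hello world (read from the global scope)
        System.out.println(global.get("name")); // global name (global left unchanged)
    }
}
```

Because define always writes to the innermost scope, binding a parameter never overwrites a global variable with the same name.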

Calling functions

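A function is simply a block of program code bundled together; once it has its own local symbol table, running it works just like running the root program. A function call has the form sayHello('hello').
Grammar rule for a function call: functionCallAst : variableAst LBRACKET arithAst (',' arithAst)* RBRACKET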

Returning from functions

A return appears inside a function body and has the form return a; its grammar rule is functionReturnAst : return arithAst
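Since return ends the function body, executing a block amounts to running its statements in order and stopping at the first return, whose value becomes the block's result. The sketch below shows only that control flow; ReturnDemo and Stmt are hypothetical stand-ins for the real AST nodes, with statements modelled as lambdas.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;

// Hypothetical model of a function body: each statement runs an action; a return statement also yields a value.
class ReturnDemo {
    static final class Stmt {
        final boolean isReturn;        // true for a "return ..." statement
        final Supplier<Object> action; // runs the statement and produces its value

        Stmt(boolean isReturn, Supplier<Object> action) {
            this.isReturn = isReturn;
            this.action = action;
        }
    }

    // Run statements in order; the first return ends the block and its value is the result.
    static Object runBlock(List<Stmt> body) {
        for (Stmt s : body) {
            Object value = s.action.get();
            if (s.isReturn) {
                return value; // statements after the return are never executed
            }
        }
        return null; // no return: the block yields null, as visitProgramAst does
    }

    public static void main(String[] args) {
        List<String> log = new ArrayList<>();
        List<Stmt> body = List.of(
            new Stmt(false, () -> { log.add("print"); return null; }),
            new Stmt(true, () -> "result"),
            new Stmt(false, () -> { log.add("after return"); return null; }));
        System.out.println(runBlock(body)); // result
        System.out.println(log);            // [print]
    }
}
```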

Recap of the previous grammar

programAst : (assignAst | printCallAst)*
assignAst : variableAst ASSIGN arithAst
printCallAst : print LBRACKET arithAst RBRACKET
arithAst : termAst ((PLUS|MINUS) termAst)* 
termAst : factorAst ((MUL|DIV) factorAst)* 
factorAst: INTEGER | LBRACKET arithAst RBRACKET | variableAst | stringAst 

The new grammar

programAst : (assignAst | printCallAst | functionDefineAst | functionCallAst | functionReturnAst)*
functionDefineAst : function variableAst LBRACKET variableAst (',' variableAst)* RBRACKET LBRACE programAst RBRACE
functionCallAst : variableAst LBRACKET arithAst (',' arithAst)* RBRACKET
functionReturnAst : return arithAst
assignAst : variableAst ASSIGN arithAst
printCallAst : print LBRACKET arithAst RBRACKET
arithAst : termAst ((PLUS|MINUS) termAst)* 
termAst : factorAst ((MUL|DIV) factorAst)* 
factorAst: INTEGER | LBRACKET arithAst RBRACKET | variableAst | stringAst | functionCallAst 

Explaining the grammar

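Three new rules are added: functionDefineAst for function definitions, functionCallAst for function calls and functionReturnAst for returns. The body of a function definition is ordinary program code, so it is defined as a programAst. programAst gains the alternatives functionDefineAst | functionCallAst | functionReturnAst. A return is normally used inside a function body; a return in the root program simply stops the remaining statements from running, and its value is discarded. factorAst gains functionCallAst because a function call produces a value that can be used in an expression, for example on the right-hand side of an assignment. With the grammar settled, we modify the code to follow it.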
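In both programAst and factorAst an ID can start either a plain variable (or an assignment) or a function call, so the parser has to look one token further ahead: an ID followed by '(' is a call. A minimal sketch of that decision, using a hypothetical token enum T instead of the real TokenType:

```java
import java.util.List;

// Hypothetical sketch: choose the statement rule from the current token and the one after it.
class LookaheadDemo {
    enum T { ID, LBRACKET, ASSIGN, PRINT, FUNCTION, RETURN, EOF }

    static String chooseRule(List<T> tokens, int pos) {
        T current = tokens.get(pos);
        T next = pos + 1 < tokens.size() ? tokens.get(pos + 1) : T.EOF; // peek without consuming
        switch (current) {
            case ID:
                return next == T.LBRACKET ? "functionCallAst" : "assignAst";
            case PRINT:
                return "printCallAst";
            case FUNCTION:
                return "functionDefineAst";
            case RETURN:
                return "functionReturnAst";
            default:
                return "end of programAst"; // e.g. '}' or end of input
        }
    }

    public static void main(String[] args) {
        System.out.println(chooseRule(List.of(T.ID, T.LBRACKET), 0)); // functionCallAst, as in sayHello(...)
        System.out.println(chooseRule(List.of(T.ID, T.ASSIGN), 0));   // assignAst, as in result = ...
    }
}
```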

Modifying the lexer

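// Token types
public enum TokenType {
    INTEGER // integer literal
    , PLUS // addition operator +
    , EOF // end of input
    , MINUS // subtraction operator -
    , MUL // multiplication operator *
    , DIV // division operator /
    , LBRACKET // left parenthesis (
    , RBRACKET // right parenthesis )
    , ID // identifier (variable or function name)
    , ASSIGN // assignment operator =
    , PRINT // built-in print function
    , QUOTE // single quote '
    , STRING // string literal
    , RETURN // keyword return
    , FUNCTION // keyword function
    , COLON // comma , (despite the name)
    , LBRACE // left brace {
    , RBRACE // right brace }
}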

// Lexer
public class Lexer {
    private String text; // the program source
    private Integer position; // current scan position
    private Character currentChar; // character at the current position (null at end of input)
    private Map<String , Token> keyWordMap = new HashMap<>(); // keyword table

    public Token peekToken(){ // return the next token without moving the position
        Integer lastPosition = position;
        Character lastChar = currentChar;
        Token token = getNextToken();
        position = lastPosition;
        currentChar = lastChar;
        return token;
    }

    public Token getNextToken(){  // read the next token
        while(this.currentChar != null){
            if(Character.isDigit(this.currentChar)){
                return this.integer();
            }else if(Character.isWhitespace(currentChar)){
                this.skipWhiteSpace();
            }else if(this.currentChar == '+'){
                Token token = new Token(TokenType.PLUS , "+");
                this.advance();
                return token; 
            }else if(this.currentChar == '-'){
                Token token = new Token(TokenType.MINUS , "-");
                this.advance();
                return token; 
            }else if(this.currentChar == '*'){
                Token token = new Token(TokenType.MUL , "*");
                this.advance();
                return token; 
            }else if(this.currentChar == '/'){
                Token token = new Token(TokenType.DIV , "/");
                this.advance();
                return token; 
            }else if(this.currentChar == '('){
                Token token = new Token(TokenType.LBRACKET , "(");
                this.advance();
                return token; 
            }else if(this.currentChar == ')'){
                Token token = new Token(TokenType.RBRACKET , ")");
                this.advance();
                return token; 
            }else if(this.currentChar == '{'){
                Token token = new Token(TokenType.LBRACE , "{");
                this.advance();
                return token; 
            }else if(this.currentChar == '}'){
                Token token = new Token(TokenType.RBRACE , "}");
                this.advance();
                return token; 
            }else if(this.currentChar == '='){
                Token token = new Token(TokenType.ASSIGN , "=");
                this.advance();
                return token; 
            }else if(this.currentChar == ','){
                Token token = new Token(TokenType.COLON , ",");
                this.advance();
                return token; 
            }else if(this.currentChar == '\''){
                return this.string();
            }else if(Character.isAlphabetic(currentChar)){
                return variable();
            }else {
                this.error("Unknown character: " + currentChar);
            }
        }
        return new Token(TokenType.EOF);
    }
    private Token string() {
        String value = "";
        this.advance();
        while(currentChar != null && currentChar != '\''){
            value += currentChar;
            this.advance();
        }
        this.advance();
        return new Token(TokenType.STRING , value);
    }
    private Token variable(){ // read an identifier or keyword
        String value = "";
        while(currentChar != null && Character.isAlphabetic(currentChar)){
            value += currentChar;
            this.advance();
        }
        Token token = keyWordMap.getOrDefault(value, new Token(TokenType.ID, value));
        return token;
    }
    public Token integer(){ // read a multi-digit integer
        String result = "";
        while(this.currentChar != null && Character.isDigit(this.currentChar)){
            result += this.currentChar;
            this.advance();
        }
        return new Token(TokenType.INTEGER ,Integer.valueOf(result));
    }

    private void skipWhiteSpace(){ // skip whitespace
        while(currentChar != null && Character.isWhitespace(currentChar)){
            this.advance();
        }
    }

    public void advance(){ // move one character forward
        this.position += 1;
        if(this.position <= this.text.length() - 1){ // still inside the input
            this.currentChar = text.charAt(this.position);
        }else{ // reached the end of the input
            this.currentChar = null;
        }
    }
    public void error(String msg){ // raise an error
        throw new RuntimeException(msg);
    }
    public Lexer(String text) {// constructor
        this.text = text;
        this.position = 0;
        this.currentChar = text.isEmpty() ? null : text.charAt(this.position); // an empty program yields EOF immediately
        keyWordMap.put("print", new Token(TokenType.PRINT , "print"));
        keyWordMap.put("return", new Token(TokenType.RETURN , "return"));
        keyWordMap.put("function", new Token(TokenType.FUNCTION , "function"));
    }
}
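A compact, hypothetical lexer (MiniLexer, not the class above) that recognizes only names, the keywords and the punctuation added in this chapter, and implements peek the same way peekToken does: save the position, scan, then restore it. Token types are plain strings here to keep the sketch short.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Hypothetical mini lexer: identifiers, keywords and the single-character symbols from this chapter.
class MiniLexer {
    private static final Map<String, String> KEYWORDS =
        Map.of("function", "FUNCTION", "return", "RETURN", "print", "PRINT");
    private static final Map<Character, String> SYMBOLS =
        Map.of('(', "LBRACKET", ')', "RBRACKET", '{', "LBRACE", '}', "RBRACE", ',', "COLON");

    private final String text;
    private int pos = 0;

    MiniLexer(String text) {
        this.text = text;
    }

    // Return the next token type and advance past it.
    String next() {
        while (pos < text.length() && Character.isWhitespace(text.charAt(pos))) {
            pos++;
        }
        if (pos >= text.length()) {
            return "EOF";
        }
        char c = text.charAt(pos);
        if (SYMBOLS.containsKey(c)) {
            pos++;
            return SYMBOLS.get(c);
        }
        if (Character.isLetter(c)) {
            int start = pos;
            while (pos < text.length() && Character.isLetter(text.charAt(pos))) {
                pos++;
            }
            return KEYWORDS.getOrDefault(text.substring(start, pos), "ID"); // keyword or plain name
        }
        throw new RuntimeException("Unknown character: " + c);
    }

    // Same idea as Lexer.peekToken(): save the position, scan, restore.
    String peek() {
        int saved = pos;
        String token = next();
        pos = saved;
        return token;
    }

    public static void main(String[] args) {
        MiniLexer lexer = new MiniLexer("function f(a, b){ return a }");
        System.out.println(lexer.peek()); // FUNCTION (position unchanged)
        List<String> types = new ArrayList<>();
        for (String t = lexer.next(); !t.equals("EOF"); t = lexer.next()) {
            types.add(t);
        }
        // [FUNCTION, ID, LBRACKET, ID, COLON, ID, RBRACKET, LBRACE, RETURN, ID, RBRACE]
        System.out.println(types);
    }
}
```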

The parser

@Data
public class FunctionCallAst extends Ast{

    private VariableAst functionName;

    private List<Ast> param = new ArrayList<>();

    public void addParam(Ast ast){
        this.param.add(ast);
    }

}
@Data
public class FunctionDefineAst extends Ast{

    private VariableAst fucntionName;

    private List<VariableAst> param = new ArrayList<>();

    private Ast body ;

    public void addParam(VariableAst variable){
        param.add(variable);
    }
}

@Data
public class FunctionReturnAst extends Ast{

    private Ast result;

    public FunctionReturnAst(Ast result) {
        this.result = result;
    }
}
// Parser
public class Parser {
    private Lexer lexer ; // the lexer
    private Token currentToken; // the current token
    public Parser(Lexer lexer) {
        this.lexer = lexer;
        this.currentToken = this.lexer.getNextToken();
    }
    public Ast programAst(){ // program node
        // programAst : (assignAst | printCallAst | functionDefineAst | functionCallAst | functionReturnAst)*
        ProgramAst ast = new ProgramAst();
        while (Arrays.asList(TokenType.ID,TokenType.PRINT,TokenType.FUNCTION,TokenType.RETURN).contains(this.currentToken.getType()) ) {
            if(this.currentToken.getType() == TokenType.ID && lexer.peekToken().getType() == TokenType.LBRACKET){
                ast.add(this.functionCallAst());
            }else if(this.currentToken.getType() == TokenType.ID){
                ast.add(this.assignAst());
            }else if(this.currentToken.getType() == TokenType.PRINT){
                ast.add(this.printCallAst());
            }else if(this.currentToken.getType() == TokenType.FUNCTION){
                ast.add(this.functionDefineAst());
            }else if(this.currentToken.getType() == TokenType.RETURN){
                ast.add(this.functionReturnAst());
            }
        }
        return ast;
    }
    public Ast functionDefineAst(){
        // functionDefineAst : function variableAst LBRACKET variableAst (',' variableAst)* RBRACKET LBRACE programAst RBRACE
        FunctionDefineAst functionDefineAst = new FunctionDefineAst();
        this.eat(TokenType.FUNCTION);
        Token name = currentToken;
        this.eat(TokenType.ID);
        this.eat(TokenType.LBRACKET);
        if (currentToken.getType() != TokenType.RBRACKET) { // an empty parameter list is also accepted
            Token param = currentToken;
            this.eat(TokenType.ID);
            functionDefineAst.addParam(new VariableAst(param));
            while(currentToken.getType() == TokenType.COLON){
                this.eat(TokenType.COLON); // consume the comma before reading the next parameter
                param = currentToken;
                this.eat(TokenType.ID);
                functionDefineAst.addParam(new VariableAst(param));
            }
        }
        this.eat(TokenType.RBRACKET);
        this.eat(TokenType.LBRACE);
        Ast body = this.programAst(); // the body stops at the closing '}'
        this.eat(TokenType.RBRACE); 
        functionDefineAst.setFucntionName(new VariableAst(name));
        functionDefineAst.setBody(body);
        return functionDefineAst;
    }
    public Ast functionCallAst(){
        // functionCallAst : variableAst LBRACKET arithAst (',' arithAst)* RBRACKET
        FunctionCallAst functionCallAst = new FunctionCallAst();
        Token name = currentToken;
        this.eat(TokenType.ID);
        this.eat(TokenType.LBRACKET);
        if (currentToken.getType() != TokenType.RBRACKET) { // an empty argument list is also accepted
            functionCallAst.addParam(this.arithAst());
            while(currentToken.getType() == TokenType.COLON){
                this.eat(TokenType.COLON); // consume the comma before parsing the next argument
                functionCallAst.addParam(this.arithAst());
            }
        }
        this.eat(TokenType.RBRACKET);
        functionCallAst.setFunctionName(new VariableAst(name));
        return functionCallAst;
    }
    public Ast functionReturnAst(){
        // functionReturnAst : return arithAst
        this.eat(TokenType.RETURN);
        Ast ast = this.arithAst();
        return new FunctionReturnAst(ast);
    }
    public Ast printCallAst(){
        // printCallAst : print LBRACKET arithAst RBRACKET
        this.eat(TokenType.PRINT);
        this.eat(TokenType.LBRACKET);
        Ast ast = this.arithAst();
        this.eat(TokenType.RBRACKET);
        return new PrintCallAst(ast);
    }

    public Ast assignAst(){
        // assignAst : variableAst ASSIGN arithAst
        Token id = this.currentToken;
        this.eat(TokenType.ID);
        this.eat(TokenType.ASSIGN);
        Ast right = this.arithAst();
        return new AssignAst((String)id.getValue() , right);
    }
    public Ast arithAst(){
        // arithAst : termAst ((PLUS|MINUS) termAst)* 
        Ast node = this.termAst();
        while(Arrays.asList(TokenType.PLUS,TokenType.MINUS).contains(this.currentToken.getType())){
            Token op = this.currentToken;
            if(op.getType() == TokenType.PLUS){
                this.eat(TokenType.PLUS);
            }else if(op.getType() == TokenType.MINUS){
                this.eat(TokenType.MINUS);
            }
            node = new ArithAst(node ,op.getType(),this.termAst());
        }
        return node;
    }
    public Ast termAst(){
        // termAst : factorAst ((MUL|DIV) factorAst)*
        Ast node = this.factorAst();
        while(Arrays.asList(TokenType.MUL,TokenType.DIV).contains(this.currentToken.getType())){
            Token op = this.currentToken;
            if(op.getType() == TokenType.MUL){
                this.eat(TokenType.MUL);
            }else if(op.getType() == TokenType.DIV){
                this.eat(TokenType.DIV);
            }
            node = new TermAst(node ,op.getType(),this.factorAst());
        }
        return node;
    }
    public Ast factorAst(){
        // factorAst: INTEGER | LBRACKET arithAst RBRACKET | variableAst | stringAst | functionCallAst 
        Token left = this.currentToken;
        if(left.getType() == TokenType.INTEGER){
            this.eat(TokenType.INTEGER);
            return new FactorAst((Integer)left.getValue());
        }else if(left.getType() == TokenType.LBRACKET){
            this.eat(TokenType.LBRACKET);
            Ast ast = this.arithAst();
            this.eat(TokenType.RBRACKET);
            return ast;
        }else if(left.getType() == TokenType.ID && lexer.peekToken().getType() == TokenType.LBRACKET){
            return this.functionCallAst();
        }else if(left.getType() == TokenType.ID){
            this.eat(TokenType.ID);
            return new VariableAst(left);
        }else if(left.getType() == TokenType.STRING){
            this.eat(TokenType.STRING);
            return new StringAst(left);
        }
        this.error("Syntax error: unexpected token " + left.getType());
        return null;
    }
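    public void eat(TokenType tokenType){ // check that the current token has the expected type, then advance
        if(tokenType == this.currentToken.getType()){
            this.currentToken = this.lexer.getNextToken();
        }else{
            this.error("Syntax error: expected " + tokenType + " but got " + this.currentToken.getType());
        }
    }
    public void error(String msg){ // raise an error
        throw new RuntimeException(msg);
    }
    public Ast parse(){ // build the syntax tree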
        return this.programAst();
    }
}
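Both the parameter list and the argument list follow the pattern first (',' next)*: read one item, then, as long as the current token is a comma, consume the comma and read another item. The comma has to be consumed before the next item is read; otherwise the parser finds ',' where it expects a name. A standalone sketch over a hypothetical list of token strings (ParamListDemo is illustrative only, not part of the parser above):

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch: parse "( [name (',' name)*] )" from a list of token strings.
class ParamListDemo {
    private final List<String> tokens;
    private int pos = 0;

    ParamListDemo(List<String> tokens) {
        this.tokens = tokens;
    }

    private String current() {
        return pos < tokens.size() ? tokens.get(pos) : "<EOF>";
    }

    // Consume an expected punctuation token or fail.
    private void eat(String symbol) {
        if (!current().equals(symbol)) {
            throw new RuntimeException("Syntax error: expected " + symbol + " but got " + current());
        }
        pos++;
    }

    // Consume a name (any token starting with a letter) or fail.
    private String name() {
        String tok = current();
        if (tok.isEmpty() || !Character.isLetter(tok.charAt(0))) {
            throw new RuntimeException("Syntax error: expected a name but got " + tok);
        }
        pos++;
        return tok;
    }

    List<String> paramList() {
        List<String> params = new ArrayList<>();
        eat("(");
        if (!current().equals(")")) {  // the list may be empty
            params.add(name());
            while (current().equals(",")) {
                eat(",");              // consume the comma first
                params.add(name());
            }
        }
        eat(")");
        return params;
    }

    public static void main(String[] args) {
        System.out.println(new ParamListDemo(List.of("(", "a", ",", "b", ")")).paramList()); // [a, b]
        System.out.println(new ParamListDemo(List.of("(", ")")).paramList());               // []
    }
}
```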

Modifying the interpreter

// Interpreter (executes the syntax tree)
public class Interpreter {
    private Parser parser; // the parser
    private Map<String , Object> symbolMap = new HashMap<>(); // global symbol table
    private Map<String, FunctionDefineAst> functionTable = new HashMap<>();  // function table: name -> definition node
    public Interpreter(Parser parser) {
        this.parser = parser;
    }
    public Object visitProgramAst(Ast ast,Map<String , Object> scope){ // visit a programAst node
        ProgramAst programAst = (ProgramAst) ast; 
        List<Ast> statementList = programAst.getStatementList();
        for(Ast statement : statementList){
            Object result = this.visit(statement,scope);
            if(statement instanceof FunctionReturnAst){ // return: stop executing and hand the value back
                return result;
            }
        }
        return null;
    }
    public void visitFunctionDefineAst(Ast ast,Map<String , Object> scope){
        FunctionDefineAst functionDefineAst = (FunctionDefineAst) ast; // store the definition in the function table
        functionTable.put(functionDefineAst.getFucntionName().getValue(), functionDefineAst);
    }
    public Object visitFunctionCallAst(Ast ast,Map<String , Object> scope){
        FunctionCallAst functionCallAst = (FunctionCallAst) ast;
        FunctionDefineAst functionDefineAst = functionTable.get(functionCallAst.getFunctionName().getValue());
        
        List<VariableAst> defineList = functionDefineAst.getParam();
        Map<String , Object> newScope = new HashMap<>(); // bind argument values to the parameters; this map is the function's local scope
        List<Ast> paramList = functionCallAst.getParam();
        for(int i=0;i<paramList.size();i++){
            newScope.put(defineList.get(i).getValue(), this.visit(paramList.get(i), scope));
        }

        return this.visit(functionDefineAst.getBody() , newScope);
    }
    public Object visitFunctionReturnAst(Ast ast,Map<String , Object> scope){
        FunctionReturnAst functionReturnAst = (FunctionReturnAst) ast;
        Object result = this.visit(functionReturnAst.getResult(),scope);
        return result;
    }
    public void visitPrintCallAst(Ast ast,Map<String , Object> scope){
        PrintCallAst printCallAst = (PrintCallAst) ast;
        System.out.println(this.visit(printCallAst.getArithAst(),scope));
    }
    public void visitAssignAst(Ast ast,Map<String , Object> scope){
        AssignAst assignAst = (AssignAst) ast;
        String name = assignAst.getLeftValue();
        Object value =  this.visit(assignAst.getRightValue(),scope);
        scope.put(name,value);
    }
    public Object visitArithAst(Ast ast,Map<String , Object> scope){
        ArithAst arithAst = (ArithAst) ast;
        if(arithAst.getOp() == TokenType.PLUS){
            Object leftObj = this.visit(arithAst.getLeftValue(),scope);
            Object rightObj = this.visit(arithAst.getRightValue(),scope);
            if(leftObj instanceof String || rightObj instanceof String){
                String left = leftObj instanceof String ? (String) leftObj : String.valueOf((Integer)leftObj);
                String right = rightObj instanceof String ? (String) rightObj : String.valueOf((Integer)rightObj);
                return left + right;
            }else {
                return (Integer) leftObj + (Integer)rightObj;
            }
        }else if(arithAst.getOp() == TokenType.MINUS){
            return (Integer)this.visit(arithAst.getLeftValue(),scope) - (Integer)this.visit(arithAst.getRightValue(),scope); // subtraction
        }
        return null;
    }   
    public Object visitTermAst(Ast ast,Map<String , Object> scope){
        TermAst termAst = (TermAst) ast; 
        if(termAst.getOp() == TokenType.MUL){
            return (Integer)this.visit(termAst.getLeftValue(),scope) * (Integer)this.visit(termAst.getRightValue(),scope); // multiplication
        }else if(termAst.getOp() == TokenType.DIV){
            return (Integer)this.visit(termAst.getLeftValue(),scope) / (Integer)this.visit(termAst.getRightValue(),scope); // integer division
        }
        return null;
    }
    public Object visitStringAst(Ast ast,Map<String , Object> scope){
        StringAst stringAst = (StringAst) ast;
        return stringAst.getValue();
    }
    public Object visitFactorAst(Ast ast,Map<String , Object> scope){
        FactorAst factorAst = (FactorAst) ast;
        return factorAst.getValue();
    }
    public Object visitVariableAst(Ast ast, Map<String , Object> scope){
        VariableAst variableAst = (VariableAst) ast;
        return scope.get(variableAst.getValue()); // look up the value in the current scope's symbol table
    }
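    public Object visit(Ast ast, Map<String , Object> scope){ // reflective dispatch: node class XxxAst is handled by visitXxxAst
        String methodName = "visit" + ast.getClass().getSimpleName();
        try {
            Method method = this.getClass().getDeclaredMethod(methodName , Ast.class , Map.class );
            return method.invoke(this, ast , scope);
        } catch (InvocationTargetException e) { // the visit method itself failed: rethrow the real error instead of hiding it
            Throwable cause = e.getCause();
            throw cause instanceof RuntimeException ? (RuntimeException) cause : new RuntimeException(cause);
        } catch (ReflectiveOperationException e) { // no visitXxxAst method for this node type
            throw new RuntimeException("No visitor method " + methodName, e);
        }
    }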
    public void expr() {
        Ast ast = parser.parse(); // build the syntax tree
        this.visit(ast, this.symbolMap); // walk the tree with the global symbol table as scope
        System.out.println(symbolMap);
    }
}

Key points

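Function definition is straightforward: the definition node is stored in the function table under its name and fetched when the function is called.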

public void visitFunctionDefineAst(Ast ast,Map<String , Object> scope){
    FunctionDefineAst functionDefineAst = (FunctionDefineAst) ast; // store the definition in the function table
    functionTable.put(functionDefineAst.getFucntionName().getValue(), functionDefineAst);
}

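A function call evaluates each argument in the caller's scope, binds the value to the corresponding parameter declared by the function, and then executes the function body with those bindings as its scope.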

public Object visitFunctionCallAst(Ast ast,Map<String , Object> scope){
    FunctionCallAst functionCallAst = (FunctionCallAst) ast;
    FunctionDefineAst functionDefineAst = functionTable.get(functionCallAst.getFunctionName().getValue());

    List<VariableAst> defineList = functionDefineAst.getParam();
    Map<String , Object> newScope = new HashMap<>(); // bind argument values to the parameters; this map is the function's local scope
    List<Ast> paramList = functionCallAst.getParam();
    for(int i=0;i<paramList.size();i++){
        newScope.put(defineList.get(i).getValue(), this.visit(paramList.get(i), scope));
    }
    return this.visit(functionDefineAst.getBody() , newScope);
}
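visitFunctionCallAst assumes that the called function has been defined and that the call passes exactly as many arguments as the function declares parameters: an unknown name makes functionTable.get return null, extra arguments push defineList.get(i) out of range, and missing ones leave parameters unbound. A minimal sketch of the lookup-and-bind step with both checks made explicit; BindDemo, lookup, bindArguments and the name-to-parameter-list table are hypothetical simplifications of FunctionDefineAst, not the chapter's code.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch: function lookup plus argument binding with explicit error checks.
class BindDemo {
    // Report an unknown function instead of failing later with a NullPointerException.
    static List<String> lookup(Map<String, List<String>> functionTable, String name) {
        List<String> params = functionTable.get(name);
        if (params == null) {
            throw new RuntimeException("Undefined function: " + name);
        }
        return params;
    }

    // Pair the i-th argument value with the i-th parameter name to build the local scope.
    static Map<String, Object> bindArguments(String functionName, List<String> params, List<Object> args) {
        if (params.size() != args.size()) {
            throw new RuntimeException(functionName + " expects " + params.size()
                + " argument(s) but got " + args.size());
        }
        Map<String, Object> localScope = new HashMap<>();
        for (int i = 0; i < params.size(); i++) {
            localScope.put(params.get(i), args.get(i));
        }
        return localScope;
    }

    public static void main(String[] args) {
        Map<String, List<String>> functionTable = new HashMap<>(); // name -> parameter names
        functionTable.put("sayHello", List.of("name"));

        List<String> params = lookup(functionTable, "sayHello");
        System.out.println(bindArguments("sayHello", params, List.of("hello"))); // {name=hello}
    }
}
```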

Running a test

LJS.txt
a = 1
b = 1 + 2
c = a + 1
d = a * b + c * 2
f = a * ( b + c ) * 2
e = 'hello world'
print('value of a:' + a)
print('value of b:' + b)
print(e + '_' + a)

function sayHello(name){
    print('printed inside function:' + name)
    return 'returned from function:' + name
}

sayHello('hello')
sayHello(e)
result = sayHello(f)
print('result_' + result)

Path resourcePath = Paths.get("study-python","demo","src", "main", "resources").toAbsolutePath();
Lexer lexer = new Lexer(FileUtil.readUtf8String(resourcePath.resolve("LJS.txt").toString())); // resolve() avoids the Windows-only "\\" separator
Parser parser = new Parser(lexer);
Interpreter interpreter = new Interpreter(parser);
interpreter.expr();

Console output:

value of a:1
value of b:3
hello world_1
printed inside function:hello
printed inside function:hello world
printed inside function:10
result_returned from function:10
{result=returned from function:10, a=1, b=3, c=2, d=7, e=hello world, f=10}