手撕JDK动态代理源码

106 阅读 · 约 2 分钟

1、动态代理原理

可以参阅：《聊聊动态代理》（掘金，juejin.cn）

2、动态生成代理类过程

需要代理的目标接口

package com.hdu;

/**
 * Demo service interface used as the proxy target in this article.
 * <p>
 * The generator produces {@code $Proxy_ITeacherService}, which implements this interface by
 * forwarding every call to an {@code InvocationHandler}.
 */
public interface ITeacherService {

    /**
     * Finds a teacher by name; the exact semantics are up to the implementation.
     *
     * @param teacherName the name to search for
     * @return the lookup result, in an implementation-defined format
     */
    String findTeacherByName(String teacherName);

    /**
     * Finds a teacher by name and id; the exact semantics are up to the implementation.
     *
     * @param teacherName the name to search for
     * @param teacherId   the id to search for
     * @return the lookup result, in an implementation-defined format
     */
    String findTeacherByNameAndId(String teacherName, String teacherId);
}

生成的代理类

package com.hdu;


import com.hdu.ITeacherService;

import com.hdu.InvocationHandler;

import java.lang.reflect.Method;

import com.hdu.Proxy;

import java.lang.reflect.UndeclaredThrowableException;

/**
 * Proxy class generated for {@link ITeacherService}.
 * <p>
 * Every interface method, plus {@code equals}/{@code hashCode}/{@code toString}, is forwarded to
 * the {@code InvocationHandler} reached through {@code super.h}, together with the
 * {@link Method} object that identifies which method was called.
 */
public final class $Proxy_ITeacherService extends Proxy implements ITeacherService {

    // Method objects handed to the handler on each call; resolved once in the static block below.
    private static Method METHOD_EQUALS;
    private static Method METHOD_TO_STRING;
    private static Method METHOD_HASH_CODE;
    private static Method METHOD_FINDTEACHERBYNAME;
    private static Method METHOD_FINDTEACHERBYNAMEANDID;

    static {
        try {
            METHOD_EQUALS = Class.forName("java.lang.Object").getMethod("equals", Class.forName("java.lang.Object"));
            METHOD_TO_STRING = Class.forName("java.lang.Object").getMethod("toString");
            METHOD_HASH_CODE = Class.forName("java.lang.Object").getMethod("hashCode");
            METHOD_FINDTEACHERBYNAME = Class.forName("com.hdu.ITeacherService").getMethod("findTeacherByName", Class.forName("java.lang.String"));
            METHOD_FINDTEACHERBYNAMEANDID = Class.forName("com.hdu.ITeacherService").getMethod("findTeacherByNameAndId", Class.forName("java.lang.String"), Class.forName("java.lang.String"));
        } catch (NoSuchMethodException var2) {
            // Reflection failures are turned into linkage errors: the class is unusable without
            // these Method objects. Only the message is kept, not the original exception.
            throw new NoSuchMethodError(var2.getMessage());
        } catch (ClassNotFoundException var3) {
            throw new NoClassDefFoundError(var3.getMessage());
        }
    }

    // Passes the handler up to Proxy; presumably stored there as the field read below via super.h.
    public $Proxy_ITeacherService(InvocationHandler var1) {
        super(var1);
    }

    // All forwarding methods share one pattern: RuntimeException and Error propagate unchanged,
    // any other Throwable (a checked exception none of these methods declare) is wrapped in
    // UndeclaredThrowableException.
    public final int hashCode() {
        try {
            // No arguments ((Object[]) null); unboxing the (Integer) result throws
            // NullPointerException if the handler returns null.
            return (Integer) super.h.invoke(this, METHOD_HASH_CODE, (Object[]) null);
        } catch (RuntimeException | Error var2) {
            throw var2;
        } catch (Throwable var3) {
            throw new UndeclaredThrowableException(var3);
        }
    }

    public final String toString() {
        try {
            return (String) super.h.invoke(this, METHOD_TO_STRING, (Object[]) null);
        } catch (RuntimeException | Error var2) {
            throw var2;
        } catch (Throwable var3) {
            throw new UndeclaredThrowableException(var3);
        }
    }

    public final boolean equals(Object var1) {
        try {
            // The (Boolean) result is unboxed, so a null result from the handler throws NullPointerException.
            return (Boolean) super.h.invoke(this, METHOD_EQUALS, new Object[]{var1});
        } catch (RuntimeException | Error var2) {
            throw var2;
        } catch (Throwable var3) {
            throw new UndeclaredThrowableException(var3);
        }
    }

    // Interface methods: the call's arguments are packed into an Object[] for the handler.
    public final String findTeacherByName(String arg1) {
        try {
            return (String) super.h.invoke(this, METHOD_FINDTEACHERBYNAME, new Object[]{arg1});
        } catch (RuntimeException | Error var2) {
            throw var2;
        } catch (Throwable var3) {
            throw new UndeclaredThrowableException(var3);
        }
    }

    public final String findTeacherByNameAndId(String arg1, String arg2) {
        try {
            return (String) super.h.invoke(this, METHOD_FINDTEACHERBYNAMEANDID, new Object[]{arg1, arg2});
        } catch (RuntimeException | Error var2) {
            throw var2;
        } catch (Throwable var3) {
            throw new UndeclaredThrowableException(var3);
        }
    }

}

只要能生成这个代理类的源码，再把它编译并加载到 JVM 中，就可以进行动态代理了。

3、怎么生成代理类?

其实可以发现，代理类的代码都是模板化的：我们只需要通过拼接字符串生成上面的代理类源码（字符串形式），再把这段源码编译（例如借助 JDK 自带的 javax.tools.JavaCompiler）并加载到 JVM 中就可以了。下面的代码展示了如何拼接出这段源码。

package com.hdu.utils;

import com.hdu.InvocationHandler;
import com.hdu.Proxy;

import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.UndeclaredThrowableException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

import static java.lang.String.format;

/**
 * Builds the Java source text of a proxy class for an interface.
 * <p>
 * For an interface {@code com.foo.IFoo} the generated class is
 * {@code com.foo.$Proxy_IFoo extends Proxy implements IFoo}. It holds one {@link Method} field per
 * proxied method (resolved in a static initializer) and forwards every call, including
 * {@code equals}/{@code hashCode}/{@code toString}, to {@code super.h.invoke(this, method, args)}.
 * Compiling and loading the returned source is left to the caller.
 */
public class JavaCodeGenerator {


    private static final String LINE_BREAK = "\n";
    private static final String SEMICOLON = ";";
    /** One indentation level in the generated source. */
    private static final String INDENT = "    ";

    /** Fields holding the three {@link Object} methods every proxy overrides. */
    private static final String FIELD_EQUALS = "METHOD_EQUALS";
    private static final String FIELD_TO_STRING = "METHOD_TO_STRING";
    private static final String FIELD_HASH_CODE = "METHOD_HASH_CODE";

    /**
     * Generates the source code of a proxy class for {@code targetClass}.
     *
     * @param targetClass the interface to proxy; methods inherited from its superinterfaces are
     *                    proxied as well
     * @return the complete Java source of {@code $Proxy_<SimpleName>}
     * @throws NullPointerException     if {@code targetClass} is {@code null}
     * @throws IllegalArgumentException if {@code targetClass} is not an interface (the generated
     *                                  class already extends {@code Proxy}, so it can only
     *                                  implement interfaces)
     */
    public static String generateProxyJavaCode(Class<?> targetClass)
        throws IllegalArgumentException {
        Objects.requireNonNull(targetClass, "targetClass");
        if (!targetClass.isInterface()) {
            throw new IllegalArgumentException(
                "Only interfaces can be proxied, but got: " + targetClass.getName());
        }

        // Resolve the proxied methods and their field names exactly once, so the field
        // declarations, the static initializer and the method bodies always agree.
        final Map<Method, String> methodFields = collectProxiedMethods(targetClass);

        final StringBuilder javaCode = new StringBuilder();

        generatePackageInfo(targetClass, javaCode);
        generateLineBreak(javaCode, 2);

        generateImportInfo(targetClass, javaCode);
        generateLineBreak(javaCode, 1);

        javaCode.append(format(
            "public final class $Proxy_%s extends Proxy implements %s {",
            targetClass.getSimpleName(),
            implementsName(targetClass)));
        generateLineBreak(javaCode, 2);

        generateTargetClassMethodDeclare(targetClass, methodFields, javaCode);
        generateLineBreak(javaCode, 1);

        generateConstructor(javaCode, targetClass);
        generateLineBreak(javaCode, 2);

        generateHashCode(javaCode);
        generateLineBreak(javaCode, 2);

        generateToString(javaCode);
        generateLineBreak(javaCode, 2);

        generateEquals(javaCode);
        generateLineBreak(javaCode, 2);

        generateTargetClassMethodImpl(targetClass, methodFields, javaCode);

        javaCode.append("}");
        return javaCode.toString();
    }

    /**
     * Collects the interface methods the proxy must implement and assigns each one a unique
     * {@code METHOD_*} field name, in the order returned by {@link Class#getMethods()}.
     * <p>
     * {@code getMethods()} is used so that methods inherited from superinterfaces are implemented
     * too (otherwise the generated class would not compile). Static methods, bridge methods and
     * redeclarations of {@code equals}/{@code hashCode}/{@code toString} are skipped; the latter
     * are already covered by the dedicated Object-method overrides. When the same signature is
     * inherited more than once, the declaration with the most specific return type is kept.
     */
    private static Map<Method, String> collectProxiedMethods(Class<?> targetClass) {
        final Map<String, Method> bySignature = new LinkedHashMap<>();
        for (Method method : targetClass.getMethods()) {
            if (Modifier.isStatic(method.getModifiers())
                || method.isBridge()
                || overridesObjectMethod(method)) {
                continue;
            }
            bySignature.merge(
                method.getName() + Arrays.toString(method.getParameterTypes()),
                method,
                (kept, candidate) ->
                    kept.getReturnType().isAssignableFrom(candidate.getReturnType()) ? candidate : kept);
        }

        final Set<String> usedFieldNames =
            new HashSet<>(Arrays.asList(FIELD_EQUALS, FIELD_TO_STRING, FIELD_HASH_CODE));
        final Map<Method, String> methodFields = new LinkedHashMap<>();
        for (Method method : bySignature.values()) {
            // Overloads (or a name like "hash_code") would otherwise map to an existing field;
            // Locale.ROOT keeps the result independent of the default locale (e.g. Turkish 'i').
            final String baseName = "METHOD_" + method.getName().toUpperCase(Locale.ROOT);
            String fieldName = baseName;
            for (int suffix = 2; !usedFieldNames.add(fieldName); suffix++) {
                fieldName = baseName + "_" + suffix;
            }
            methodFields.put(method, fieldName);
        }
        return methodFields;
    }

    /** Whether {@code method} has the signature of {@code equals(Object)}, {@code hashCode()} or {@code toString()}. */
    private static boolean overridesObjectMethod(Method method) {
        final Class<?>[] parameterTypes = method.getParameterTypes();
        switch (method.getName()) {
            case "equals":
                return parameterTypes.length == 1 && parameterTypes[0] == Object.class;
            case "hashCode":
            case "toString":
                return parameterTypes.length == 0;
            default:
                return false;
        }
    }

    /** Appends one forwarding implementation per proxied interface method. */
    private static void generateTargetClassMethodImpl(
        Class<?> targetClass, Map<Method, String> methodFields, StringBuilder javaCode) {
        methodFields.forEach((method, fieldName) -> generateProxyMethod(targetClass, method, fieldName, javaCode));
    }

    /**
     * Appends the forwarding implementation of one interface method.
     * <p>
     * Arguments are named {@code arg1..argN} and packed into an {@code Object[]} for the handler.
     * Declared checked exceptions are copied into a {@code throws} clause and rethrown unchanged,
     * like {@code RuntimeException}/{@code Error}; any other {@code Throwable} is wrapped in
     * {@code UndeclaredThrowableException}.
     */
    private static void generateProxyMethod(
        Class<?> targetClass, Method method, String fieldName, StringBuilder javaCode) {
        final Class<?>[] parameterTypes = method.getParameterTypes();
        final List<String> parameters = new ArrayList<>(parameterTypes.length);
        final List<String> arguments = new ArrayList<>(parameterTypes.length);
        for (int i = 0; i < parameterTypes.length; i++) {
            final String argumentName = "arg" + (i + 1);
            parameters.add(sourceTypeName(parameterTypes[i], targetClass) + " " + argumentName);
            arguments.add(argumentName);
        }

        final Class<?>[] declaredExceptions = method.getExceptionTypes();
        final String throwsClause = declaredExceptions.length == 0
            ? ""
            : " throws " + joinTypeNames(Arrays.asList(declaredExceptions), targetClass, ", ");

        final Class<?> returnType = method.getReturnType();
        appendLine(javaCode, 1, format("public final %s %s(%s)%s {",
            sourceTypeName(returnType, targetClass),
            method.getName(),
            String.join(", ", parameters),
            throwsClause));

        final String invocation = format("super.h.invoke(this, %s, new Object[]{%s})",
            fieldName, String.join(", ", arguments));
        // "return (void) ..." does not compile, so void methods simply discard the handler result.
        // For primitive return types the cast from Object unboxes the result (Java 7+).
        final String statement = returnType == void.class
            ? invocation + SEMICOLON
            : format("return (%s) %s;", sourceTypeName(returnType, targetClass), invocation);

        final List<Class<?>> rethrown = rethrowableExceptions(method);
        if (rethrown.isEmpty()) {
            // The method declares Throwable, so every exception may propagate unchanged.
            appendLine(javaCode, 2, statement);
        } else {
            appendGuardedStatement(javaCode, statement, joinTypeNames(rethrown, targetClass, " | "));
        }
        appendLine(javaCode, 1, "}");
        generateLineBreak(javaCode, 1);
    }

    /**
     * Computes the exception types a generated method rethrows unchanged: RuntimeException, Error
     * and the method's declared exceptions, minus any type already covered by a supertype in the
     * list (a multi-catch may not name a type together with its subtype). An empty list means
     * the method declares Throwable and needs no try/catch at all.
     */
    private static List<Class<?>> rethrowableExceptions(Method method) {
        final List<Class<?>> rethrown = new ArrayList<>();
        rethrown.add(RuntimeException.class);
        rethrown.add(Error.class);
        for (Class<?> declared : method.getExceptionTypes()) {
            if (declared == Throwable.class) {
                rethrown.clear();
                return rethrown;
            }
            if (rethrown.stream().anyMatch(type -> type.isAssignableFrom(declared))) {
                continue; // already covered by a broader type
            }
            rethrown.removeIf(declared::isAssignableFrom); // drop subtypes now covered by `declared`
            rethrown.add(declared);
        }
        return rethrown;
    }

    /**
     * Renders {@code type} as it must appear in the generated source. Primitives, top-level
     * {@code java.lang} types and top-level types from the target's own package keep their simple
     * name; everything else is fully qualified because the generated file imports nothing else.
     * NOTE: a type in the target's package that shadows a java.lang name is not handled.
     */
    private static String sourceTypeName(Class<?> type, Class<?> targetClass) {
        if (type.isArray()) {
            return sourceTypeName(type.getComponentType(), targetClass) + "[]";
        }
        if (type.isPrimitive()) {
            return type.getName();
        }
        if (type.getEnclosingClass() == null) {
            final String packageName = packageNameOf(type);
            if (packageName.equals("java.lang") || packageName.equals(packageNameOf(targetClass))) {
                return type.getSimpleName();
            }
        }
        return type.getCanonicalName();
    }

    /** Joins the source names of {@code types} with {@code separator}. */
    private static String joinTypeNames(List<Class<?>> types, Class<?> targetClass, String separator) {
        return types.stream()
            .map(type -> sourceTypeName(type, targetClass))
            .collect(Collectors.joining(separator));
    }

    /**
     * Expression evaluating to the {@code Class} of a parameter type in the static initializer.
     * {@code Class.forName} cannot load primitives, so those use a class literal instead.
     */
    private static String classReference(Class<?> type) {
        if (type.isPrimitive()) {
            return type.getName() + ".class";
        }
        // getName() yields the binary name ("[Ljava.lang.String;", "a.Outer$Inner"), which is
        // exactly what Class.forName expects.
        return format("Class.forName(\"%s\")", type.getName());
    }

    /**
     * Package name derived from the binary class name; "" for the unnamed package. Used instead of
     * {@code getPackage()}, which may return {@code null} for the unnamed package on Java 8.
     */
    private static String packageNameOf(Class<?> type) {
        final String name = type.getName();
        final int lastDot = name.lastIndexOf('.');
        return lastDot < 0 ? "" : name.substring(0, lastDot);
    }

    /**
     * Name of the target in the {@code implements} clause: the simple name (the target is
     * imported), or the canonical name in the unnamed package, where imports are not allowed.
     */
    private static String implementsName(Class<?> targetClass) {
        return packageNameOf(targetClass).isEmpty()
            ? targetClass.getCanonicalName()
            : targetClass.getSimpleName();
    }

    private static void generateLineBreak(StringBuilder javaCode, int len) {
        for (int i = 0; i < len; i++) {
            javaCode.append(LINE_BREAK);
        }
    }

    /** Appends {@code text} indented by {@code depth} levels, followed by a line break. */
    private static void appendLine(StringBuilder javaCode, int depth, String text) {
        for (int i = 0; i < depth; i++) {
            javaCode.append(INDENT);
        }
        javaCode.append(text).append(LINE_BREAK);
    }

    /**
     * Appends {@code statement} wrapped in the standard proxy try/catch: {@code rethrownTypes}
     * (a multi-catch list) propagate unchanged, any other Throwable is wrapped in
     * UndeclaredThrowableException.
     * NOTE(review): assumes {@code InvocationHandler.invoke} declares {@code throws Throwable}
     * (as the JDK's does); otherwise catching a declared checked exception would not compile.
     */
    private static void appendGuardedStatement(StringBuilder javaCode, String statement, String rethrownTypes) {
        appendLine(javaCode, 2, "try {");
        appendLine(javaCode, 3, statement);
        appendLine(javaCode, 2, format("} catch (%s var2) {", rethrownTypes));
        appendLine(javaCode, 3, "throw var2;");
        appendLine(javaCode, 2, "} catch (Throwable var3) {");
        appendLine(javaCode, 3, "throw new UndeclaredThrowableException(var3);");
        appendLine(javaCode, 2, "}");
    }

    private static void generateImportInfo(Class<?> targetClass, StringBuilder javaCode) {
        // import com.hdu.IStudentService;  (types in the unnamed package cannot be imported; the
        // proxy then shares that package and needs no import)
        if (!packageNameOf(targetClass).isEmpty()) {
            appendImport(javaCode, targetClass.getCanonicalName());
        }

        javaCode.append(LINE_BREAK);

        appendImport(javaCode, InvocationHandler.class.getName());
        appendImport(javaCode, Method.class.getName());
        appendImport(javaCode, Proxy.class.getName());
        appendImport(javaCode, UndeclaredThrowableException.class.getName());
    }

    /** Appends a single-type import line. */
    private static void appendImport(StringBuilder javaCode, String qualifiedName) {
        javaCode.append("import ").append(qualifiedName).append(SEMICOLON).append(LINE_BREAK);
    }

    private static void generateHashCode(StringBuilder javaCode) {
        generateObjectMethod(javaCode, "int hashCode()",
            "return (Integer) super.h.invoke(this, " + FIELD_HASH_CODE + ", (Object[]) null);");
    }

    private static void generateToString(StringBuilder javaCode) {
        generateObjectMethod(javaCode, "String toString()",
            "return (String) super.h.invoke(this, " + FIELD_TO_STRING + ", (Object[]) null);");
    }

    private static void generateEquals(StringBuilder javaCode) {
        generateObjectMethod(javaCode, "boolean equals(Object var1)",
            "return (Boolean) super.h.invoke(this, " + FIELD_EQUALS + ", new Object[]{var1});");
    }

    /**
     * Appends an override of an {@link Object} method that forwards to the handler. No trailing
     * line break is written; the caller adds the spacing.
     */
    private static void generateObjectMethod(StringBuilder javaCode, String signature, String statement) {
        appendLine(javaCode, 1, "public final " + signature + " {");
        appendGuardedStatement(javaCode, statement, "RuntimeException | Error");
        javaCode.append(INDENT).append("}");
    }

    private static void generatePackageInfo(Class<?> targetClass, StringBuilder javaCode) {
        // package com.hdu;  (omitted for the unnamed package, where "package ;" would not compile)
        final String packageName = packageNameOf(targetClass);
        if (!packageName.isEmpty()) {
            javaCode.append("package ")
                .append(packageName)
                .append(SEMICOLON)
                .append(LINE_BREAK);
        }
    }


    private static void generateConstructor(StringBuilder javaCode, Class<?> targetClass) {
        appendLine(javaCode, 1, format("public $Proxy_%s(InvocationHandler var1) {", targetClass.getSimpleName()));
        appendLine(javaCode, 2, "super(var1);");
        javaCode.append(INDENT).append("}");
    }

    /**
     * Appends the {@code private static Method} fields and the static initializer that resolves
     * them; reflection failures become NoSuchMethodError / NoClassDefFoundError at class init.
     */
    private static void generateTargetClassMethodDeclare(
        Class<?> targetClass, Map<Method, String> methodFields, StringBuilder javaCode) {
        appendLine(javaCode, 1, "private static Method " + FIELD_EQUALS + SEMICOLON);
        appendLine(javaCode, 1, "private static Method " + FIELD_TO_STRING + SEMICOLON);
        appendLine(javaCode, 1, "private static Method " + FIELD_HASH_CODE + SEMICOLON);
        methodFields.values().forEach(
            fieldName -> appendLine(javaCode, 1, "private static Method " + fieldName + SEMICOLON));

        javaCode.append(LINE_BREAK);

        appendLine(javaCode, 1, "static {");
        appendLine(javaCode, 2, "try {");
        appendLine(javaCode, 3, FIELD_EQUALS
            + " = Class.forName(\"java.lang.Object\").getMethod(\"equals\", Class.forName(\"java.lang.Object\"));");
        appendLine(javaCode, 3, FIELD_TO_STRING + " = Class.forName(\"java.lang.Object\").getMethod(\"toString\");");
        appendLine(javaCode, 3, FIELD_HASH_CODE + " = Class.forName(\"java.lang.Object\").getMethod(\"hashCode\");");

        // Class.getMethod on an interface also searches its superinterfaces, so inherited
        // methods resolve through the target as well.
        methodFields.forEach((method, fieldName) -> {
            final StringBuilder lookup = new StringBuilder()
                .append(fieldName)
                .append(format(" = Class.forName(\"%s\").getMethod(\"%s\"", targetClass.getName(), method.getName()));
            for (Class<?> parameterType : method.getParameterTypes()) {
                lookup.append(", ").append(classReference(parameterType));
            }
            lookup.append(")").append(SEMICOLON);
            appendLine(javaCode, 3, lookup.toString());
        });

        appendLine(javaCode, 2, "} catch (NoSuchMethodException var2) {");
        appendLine(javaCode, 3, "throw new NoSuchMethodError(var2.getMessage());");
        appendLine(javaCode, 2, "} catch (ClassNotFoundException var3) {");
        appendLine(javaCode, 3, "throw new NoClassDefFoundError(var3.getMessage());");
        appendLine(javaCode, 2, "}");
        appendLine(javaCode, 1, "}");
    }

}

4、源码

proxy_demo: proxy_demo (gitee.com)