Implementing a Simplified cglib Dynamic Proxy with ASM


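In the previous post we analyzed the class files that cglib generates, which gave us a rough picture of how cglib implements dynamic proxies (previous post: "How cglib Dynamic Proxies Work").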

Today we'll build a cglib-style dynamic proxy ourselves with ASM.

First, let's define a few classes and interfaces.

Enhancer

  • setSuperClass: sets the base class to be proxied
  • setMethodInterceptor: sets the callback
  • create: generates the proxy instance
package simple;

import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;

public class Enhancer {
    private Class<?> superClass;
    private MethodInterceptor methodInterceptor;

    /**
     * Sets the superclass, i.e. the class to be proxied.
     *
     * @param superClass the class the generated proxy will extend
     */
    public void setSuperClass(final Class<?> superClass) {
        this.superClass = superClass;
    }

    /**
     * Sets the callback (interceptor) instance.
     *
     * @param methodInterceptor the interceptor that every proxied call is routed through
     */
    public void setMethodInterceptor(final MethodInterceptor methodInterceptor) {
        this.methodInterceptor = methodInterceptor;
    }

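    /**
     * Creates the dynamic proxy instance.
     *
     * @return the proxy instance; a plain superClass instance if no interceptor is set,
     *         or null if the generated class cannot be instantiated
     */
    public Object create() throws IOException {
        if (methodInterceptor == null) {
            // no interceptor: nothing to intercept, so just instantiate the base class
            try {
                return superClass.newInstance();
            } catch (InstantiationException | IllegalAccessException e) {
                // fail here instead of falling through and generating a proxy with a null interceptor
                throw new IllegalStateException("cannot instantiate " + superClass.getName(), e);
            }
        }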
        String className = "$ASMProxy";
        byte[] codeBytes = EnhancerFactory.generate(className, superClass);
        // load the bytecode with a custom class loader
        ASMClassLoader asmClassLoader = new ASMClassLoader();
        asmClassLoader.add(className, codeBytes);
        try {
            Class<?> aClass = asmClassLoader.loadClass(className);
            Constructor<?> constructor = aClass.getConstructor(MethodInterceptor.class);
            return constructor.newInstance(methodInterceptor);
        } catch (ClassNotFoundException | NoSuchMethodException | InvocationTargetException | InstantiationException | IllegalAccessException e) {
            e.printStackTrace();
        }
        return null;
    }
}

MethodInterceptor

MethodInterceptor is the callback interface; it declares a single method, intercept:

package simple;

import java.lang.reflect.Method;

public interface MethodInterceptor {
    Object intercept(Object obj, Method method, Object[] args, Method proxy) throws Throwable;
}
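  • obj: the generated proxy instance
  • method: the original method of the base class
  • args: the array of call arguments
  • proxy: the proxy-side helper method, through which the original implementation can be called (real cglib passes a MethodProxy here; this simplified version uses a plain Method)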
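The Main program below uses a UserMethodInterceptor that the article never shows. The following is a hypothetical version, written only to illustrate the four parameters: it logs the call and delegates through proxy. The class names UserMethodInterceptorDemo and run are made up for this sketch, and because there is no generated class here, run passes the same Method as both method and proxy against a plain UserService.

```java
import java.lang.reflect.Method;
import java.util.Arrays;

public class UserMethodInterceptorDemo {
    // Same shape as simple.MethodInterceptor in the article
    interface MethodInterceptor {
        Object intercept(Object obj, Method method, Object[] args, Method proxy) throws Throwable;
    }

    public static class UserService {
        public boolean login(String username, String password) throws Throwable {
            return "admin".equals(username) && "admin".equals(password);
        }
    }

    // Hypothetical UserMethodInterceptor: logs around the call and delegates through proxy
    static class UserMethodInterceptor implements MethodInterceptor {
        final StringBuilder log = new StringBuilder();

        @Override
        public Object intercept(Object obj, Method method, Object[] args, Method proxy) throws Throwable {
            log.append("before ").append(method.getName()).append(Arrays.toString(args));
            // Delegate through proxy. Inside a real generated proxy, method.invoke(obj, args)
            // would dispatch to the overriding login again and recurse indefinitely.
            Object result = proxy.invoke(obj, args);
            log.append(" after -> ").append(result);
            return result;
        }
    }

    // Drives the interceptor directly; with no generated class, login serves as both method and proxy
    static String run(String username, String password) {
        try {
            Method login = UserService.class.getMethod("login", String.class, String.class);
            UserMethodInterceptor interceptor = new UserMethodInterceptor();
            Object result = interceptor.intercept(new UserService(), login,
                    new Object[]{username, password}, login);
            return interceptor.log + " | " + result;
        } catch (Throwable t) {
            throw new IllegalStateException(t);
        }
    }

    public static void main(String[] args) {
        System.out.println(run("admin", "admin"));
    }
}
```

The key design point is the last parameter: the interceptor must reach the original logic through proxy, never through method, once it runs inside a real proxy.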

Creating a proxy looks like this:

package simple.test;

import simple.Enhancer;

public class Main {
    public static void main(String[] args) throws Throwable {
        Enhancer enhancer = new Enhancer();
        // the class to proxy
        enhancer.setSuperClass(UserService.class);
        // the callback; the proxy logic lives here
        enhancer.setMethodInterceptor(new UserMethodInterceptor());
        // create the proxy instance
        UserService service = (UserService) enhancer.create();
        
        System.out.println(service.getClass().getName());
        System.out.println(service.login("admin", "admin"));
        System.out.println(service.login("admin", "admin1"));
    }
}

Let's first compare the base class with the proxy class that is eventually generated:

  • The base class to be proxied
package simple.test;

public class UserService {
    public boolean login(String username, String password) throws Throwable {
        return "admin".equals(username) && "admin".equals(password);
    }
}
  • The generated proxy class (as shown by a decompiler)
import java.lang.reflect.Method;
import simple.MethodInterceptor;
import simple.test.UserService;

public class $ASMProxy extends UserService {
    private MethodInterceptor methodInterceptor;
    private static Method _METHOD_login0 = Class.forName("simple.test.UserService").getMethod("login", Class.forName("java.lang.String"), Class.forName("java.lang.String"));
    private static Method _METHOD_ASM_login0 = Class.forName("$ASMProxy").getMethod("_asm_login_0", Class.forName("java.lang.String"), Class.forName("java.lang.String"));

    public $ASMProxy(MethodInterceptor var1) {
        this.methodInterceptor = var1;
    }

    public boolean _asm_login_0(String var1, String var2) throws Throwable {
        return super.login(var1, var2);
    }

    public boolean login(String var1, String var2) throws Exception {
        return (Boolean)this.methodInterceptor.intercept(this, _METHOD_login0, new Object[]{var1, var2}, _METHOD_ASM_login0);
    }
}

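What the proxy class does:

  • Provides a constructor that takes a MethodInterceptor
  • Holds static fields for the original method and the proxy-side helper method
  • Overrides login so that it calls MethodInterceptor.intercept
  • Adds a helper method, _asm_login_0, that runs the original login logic via super.login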
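To make these four points concrete without ASM, here is a hand-written Java class that mirrors the decompiled $ASMProxy. The names HandWrittenProxyDemo and UserServiceProxy are made up, the interceptor is a lambda, and the code is an illustration of the structure the generator targets, not output of the generator itself.

```java
import java.lang.reflect.Method;

public class HandWrittenProxyDemo {
    // Same shape as simple.MethodInterceptor in the article
    interface MethodInterceptor {
        Object intercept(Object obj, Method method, Object[] args, Method proxy) throws Throwable;
    }

    public static class UserService {
        public boolean login(String username, String password) throws Throwable {
            return "admin".equals(username) && "admin".equals(password);
        }
    }

    // Hand-written counterpart of the generated $ASMProxy (hypothetical name)
    public static class UserServiceProxy extends UserService {
        // Static fields: the original method and the helper method
        private static final Method METHOD_LOGIN_0;
        private static final Method METHOD_ASM_LOGIN_0;

        // What the generated <clinit> does: look both methods up by reflection
        static {
            try {
                METHOD_LOGIN_0 = UserService.class.getMethod("login", String.class, String.class);
                METHOD_ASM_LOGIN_0 = UserServiceProxy.class.getMethod("_asm_login_0", String.class, String.class);
            } catch (NoSuchMethodException e) {
                throw new ExceptionInInitializerError(e);
            }
        }

        private final MethodInterceptor methodInterceptor;

        // Constructor taking the interceptor
        public UserServiceProxy(MethodInterceptor methodInterceptor) {
            this.methodInterceptor = methodInterceptor;
        }

        // Helper method: runs the original logic via super
        public boolean _asm_login_0(String username, String password) throws Throwable {
            return super.login(username, password);
        }

        // Overriding method: every call is routed through intercept
        @Override
        public boolean login(String username, String password) throws Throwable {
            return (Boolean) methodInterceptor.intercept(this, METHOD_LOGIN_0,
                    new Object[]{username, password}, METHOD_ASM_LOGIN_0);
        }
    }

    // Calls login twice through the proxy; returns the interceptor log plus both results
    static String demo() {
        StringBuilder log = new StringBuilder();
        UserService service = new UserServiceProxy((obj, method, args, proxy) -> {
            log.append(method.getName()).append(';');
            return proxy.invoke(obj, args);
        });
        try {
            boolean first = service.login("admin", "admin");
            boolean second = service.login("admin", "admin1");
            return log + " " + first + " " + second;
        } catch (Throwable t) {
            throw new IllegalStateException(t);
        }
    }

    public static void main(String[] args) {
        System.out.println(demo());
    }
}
```

Everything the ASM code emits later corresponds to one of these members: the constructor, the two static fields and their static initializer, the helper method, and the overriding method.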

Now that the shape of the generated class is clear, let's generate it with the ASM framework.

Back in the Enhancer class, create() implements the proxy-generation logic:

package simple;

import java.io.IOException;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;

public class Enhancer {
    private Class<?> superClass;
    private MethodInterceptor methodInterceptor;

    /**
     * Sets the superclass, i.e. the class to be proxied.
     *
     * @param superClass the class the generated proxy will extend
     */
    public void setSuperClass(final Class<?> superClass) {
        this.superClass = superClass;
    }

    /**
     * Sets the callback (interceptor) instance.
     *
     * @param methodInterceptor the interceptor that every proxied call is routed through
     */
    public void setMethodInterceptor(final MethodInterceptor methodInterceptor) {
        this.methodInterceptor = methodInterceptor;
    }

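    /**
     * Creates the dynamic proxy instance.
     *
     * @return the proxy instance; a plain superClass instance if no interceptor is set,
     *         or null if the generated class cannot be instantiated
     */
    public Object create() throws IOException {
        if (methodInterceptor == null) {
            // no interceptor: nothing to intercept, so just instantiate the base class
            try {
                return superClass.newInstance();
            } catch (InstantiationException | IllegalAccessException e) {
                // fail here instead of falling through and generating a proxy with a null interceptor
                throw new IllegalStateException("cannot instantiate " + superClass.getName(), e);
            }
        }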
        String className = "$ASMProxy";
        // generate the proxy class bytecode
        byte[] codeBytes = EnhancerFactory.generate(className, superClass);
        // load the bytecode with a custom class loader
        ASMClassLoader asmClassLoader = new ASMClassLoader();
        asmClassLoader.add(className, codeBytes);
        try {
            Class<?> aClass = asmClassLoader.loadClass(className);
            Constructor<?> constructor = aClass.getConstructor(MethodInterceptor.class);
            return constructor.newInstance(methodInterceptor);
        } catch (ClassNotFoundException | NoSuchMethodException | InvocationTargetException | InstantiationException | IllegalAccessException e) {
            e.printStackTrace();
        }
        return null;
    }
}

Here create() calls EnhancerFactory.generate(className, superClass) to produce the proxy class as a byte array.

Next, let's implement EnhancerFactory.

EnhancerFactory

Overall it takes four steps:

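  • Implement the <init> method (the constructor)
  • Add the static fields
  • Implement the <clinit> method (the static initializer)
  • Generate the overriding methods and the helper methods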
package simple;

import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;

import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class EnhancerFactory {
    public static byte[] generate(String proxyClassName, Class<?> superClass) {
        ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS);
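        // class file version, access flags, class name, generic signature, superclass, interfaces
        cw.visit(Opcodes.V1_8, Opcodes.ACC_PUBLIC, proxyClassName, null, Type.getInternalName(superClass), null);
        // <init>: the constructor
        createInit(cw, superClass, proxyClassName);
        // static fields
        addStaticFields(cw, superClass);
        // <clinit>: the static initializer
        addClinit(cw, superClass, proxyClassName);
        // overriding methods and helper methods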
        addSuperMethodImpl(cw, superClass, proxyClassName);
        cw.visitEnd();
        return cw.toByteArray();
    }
}
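Because generate() only returns raw bytes, it can help to write them to a .class file and inspect the result with javap or a decompiler (a decompiled view like the $ASMProxy listing above is one such result). The sketch below assumes nothing about the article's code beyond the byte array; ClassDumper and dump are hypothetical names, and the stand-in bytes in main are not a valid class file.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

public class ClassDumper {
    // Hypothetical debugging helper: writes bytes to <dir>/<internalName>.class,
    // creating package directories if the internal name contains '/'
    static Path dump(Path dir, String internalName, byte[] classBytes) throws IOException {
        Path file = dir.resolve(internalName + ".class");
        Path parent = file.getParent();
        if (parent != null) {
            Files.createDirectories(parent);
        }
        return Files.write(file, classBytes);
    }

    public static void main(String[] args) throws IOException {
        // Stand-in bytes; in the article they would come from
        // EnhancerFactory.generate("$ASMProxy", UserService.class)
        byte[] bytes = {(byte) 0xCA, (byte) 0xFE, (byte) 0xBA, (byte) 0xBE};
        System.out.println(dump(Paths.get("asm-debug"), "$ASMProxy", bytes));
    }
}
```

With real generated bytes, `javap -c -p 'asm-debug/$ASMProxy.class'` prints the emitted instructions; the single quotes stop the shell from expanding `$ASMProxy`.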

Implementing the <init> method

Generate a constructor that takes a MethodInterceptor argument and stores it in a field:

public $ASMProxy(MethodInterceptor var1) {
    this.methodInterceptor = var1;
}
private static void createInit(ClassWriter cw, Class<?> superClass, String proxyClassName) {
    cw.visitField(Opcodes.ACC_PRIVATE, "methodInterceptor", Type.getDescriptor(MethodInterceptor.class), null, null);
    MethodVisitor mv = cw.visitMethod(Opcodes.ACC_PUBLIC, "<init>", "(" + Type.getDescriptor(MethodInterceptor.class) + ")V", null, null);
    mv.visitCode();
    // push this
    mv.visitVarInsn(Opcodes.ALOAD, 0);
    // super(): the base class needs an accessible no-arg constructor
    mv.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(superClass), "<init>", "()V", false);
    // push this and the constructor argument
    mv.visitVarInsn(Opcodes.ALOAD, 0);
    mv.visitVarInsn(Opcodes.ALOAD, 1);
    // this.methodInterceptor = argument
    mv.visitFieldInsn(Opcodes.PUTFIELD, proxyClassName.replace('.', '/'), "methodInterceptor", Type.getDescriptor(MethodInterceptor.class));
    // return
    mv.visitInsn(Opcodes.RETURN);
    mv.visitMaxs(2, 2);
    mv.visitEnd();
}
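Strings such as `"(" + Type.getDescriptor(MethodInterceptor.class) + ")V"` follow the JVM's method-descriptor format. As a quick cross-check that needs no ASM, the JDK's java.lang.invoke.MethodType can print the same descriptors. In this sketch Runnable stands in for simple.MethodInterceptor so it compiles on its own; for the article's interface the constructor descriptor would be `(Lsimple/MethodInterceptor;)V`. DescriptorDemo and its helpers are hypothetical names.

```java
import java.lang.invoke.MethodType;

public class DescriptorDemo {
    // Descriptor of a constructor with one parameter: "(" + parameter descriptor + ")V"
    static String constructorDescriptor(Class<?> paramType) {
        return MethodType.methodType(void.class, paramType).toMethodDescriptorString();
    }

    // Internal name of a non-array class: the binary name with '.' replaced by '/'
    static String internalName(Class<?> type) {
        return type.getName().replace('.', '/');
    }

    public static void main(String[] args) {
        System.out.println(constructorDescriptor(Runnable.class));
        System.out.println(internalName(java.util.Map.Entry.class));
        // descriptor of boolean login(String, String)
        System.out.println(MethodType.methodType(boolean.class, String.class, String.class)
                .toMethodDescriptorString());
    }
}
```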

Adding the static fields

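Generate the static fields that hold the Method objects (the original method and its helper method, for each proxied method):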

private static void addStaticFields(ClassWriter cw, Class<?> superClass) {
    Method[] methods = getMethods(superClass);
    for (int i = 0; i < methods.length; i++) {
        String fieldName = "_METHOD_" + methods[i].getName() + i;
        String asmFieldName = "_METHOD_ASM_" + methods[i].getName() + i;
        cw.visitField(Opcodes.ACC_PRIVATE | Opcodes.ACC_STATIC, fieldName, Type.getDescriptor(Method.class), null, null);
        cw.visitField(Opcodes.ACC_PRIVATE | Opcodes.ACC_STATIC, asmFieldName, Type.getDescriptor(Method.class), null, null);
    }

}

private static final List<String> FILTER_METHOD_NAMES = newArrayList();

private static List<String> newArrayList() {
    List<String> list = new ArrayList<>();
    list.add("wait");
    list.add("equals");
    list.add("toString");
    list.add("hashCode");
    list.add("getClass");
    list.add("notify");
    list.add("notifyAll");
    return list;
}
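// filter out methods that should not be proxied
// note: the order of getMethods() is unspecified; this code calls it three times and relies on
// the order (and therefore the index i used in field and method names) being stable between calls
private static Method[] getMethods(Class<?> superClass) {
    return Arrays.stream(superClass.getMethods())
            .filter(it -> !FILTER_METHOD_NAMES.contains(it.getName()))
            // final and static methods cannot be overridden; test the bits, because the
            // modifiers of a public final method are PUBLIC | FINAL, not FINAL alone
            .filter(it -> !Modifier.isFinal(it.getModifiers()) && !Modifier.isStatic(it.getModifiers()))
            .toArray(Method[]::new);
}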
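Method modifiers are a bitmask, so comparing getModifiers() to Modifier.FINAL with == or != only matches a method whose sole modifier is final. The small sketch below shows the difference on a public final method; FinalFilterDemo, Base and the helper names are made up for illustration.

```java
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;

public class FinalFilterDemo {
    public static class Base {
        public final void locked() {}
        public void open() {}
    }

    // Equality only matches a method whose *only* modifier is final
    static boolean equalityCheck(Method m) {
        return m.getModifiers() == Modifier.FINAL;
    }

    // Bitmask test: true whenever the FINAL bit is set
    static boolean bitmaskCheck(Method m) {
        return Modifier.isFinal(m.getModifiers());
    }

    static Method method(String name) {
        try {
            return Base.class.getMethod(name);
        } catch (NoSuchMethodException e) {
            throw new IllegalArgumentException(e);
        }
    }

    public static void main(String[] args) {
        Method locked = method("locked");
        System.out.println(Modifier.toString(locked.getModifiers())); // public final
        System.out.println(equalityCheck(locked));                    // false
        System.out.println(bitmaskCheck(locked));                     // true
    }
}
```

A generated class that overrides a final method is rejected when it is loaded, so the bitmask form is the one a proxy generator needs.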

Implementing the <clinit> method

Assign the generated static fields in <clinit>, looking each Method up by reflection:

    private static void addClinit(ClassWriter cw, Class<?> superClass, String proxyClassName) {
        MethodVisitor mv = cw.visitMethod(Opcodes.ACC_STATIC, "<clinit>", "()V", null, null);
        mv.visitCode();
        Method[] methods = getMethods(superClass);
        for (int i = 0; i < methods.length; i++) {
            generateMethod(superClass, proxyClassName, mv, methods[i], i);
            generateASMMethod(superClass, proxyClassName, mv, methods[i], i);
        }
        mv.visitInsn(Opcodes.RETURN);
        mv.visitMaxs(2, 2);
        mv.visitEnd();
    }

    private static void generateMethod(Class<?> superClass, String proxyClassName, MethodVisitor mv, Method method, int i) {
        String fieldName = "_METHOD_" + method.getName() + i;
        mv.visitLdcInsn(superClass.getName());
        mv.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Class.class), "forName", "(Ljava/lang/String;)Ljava/lang/Class;", false);
        mv.visitLdcInsn(method.getName());
        if (method.getParameterCount() == 0) {
            mv.visitInsn(Opcodes.ACONST_NULL);
        } else {
            switch (method.getParameterCount()) {
                case 1:
                    mv.visitInsn(Opcodes.ICONST_1);
                    break;
                case 2:
                    mv.visitInsn(Opcodes.ICONST_2);
                    break;
                case 3:
                    mv.visitInsn(Opcodes.ICONST_3);
                    break;
                default:
                    // BIPUSH carries an immediate operand, so it is emitted with visitIntInsn
                    mv.visitIntInsn(Opcodes.BIPUSH, method.getParameterCount());
                    break;
            }
            mv.visitTypeInsn(Opcodes.ANEWARRAY, Type.getInternalName(Class.class));
            for (int paramIndex = 0; paramIndex < method.getParameterTypes().length; paramIndex++) {
                Class<?> parameter = method.getParameterTypes()[paramIndex];
                mv.visitInsn(Opcodes.DUP);
                switch (paramIndex) {
                    case 0:
                        mv.visitInsn(Opcodes.ICONST_0);
                        break;
                    case 1:
                        mv.visitInsn(Opcodes.ICONST_1);
                        break;
                    case 2:
                        mv.visitInsn(Opcodes.ICONST_2);
                        break;
                    case 3:
                        mv.visitInsn(Opcodes.ICONST_3);
                        break;
                    default:
                        mv.visitIntInsn(Opcodes.BIPUSH, paramIndex);
                        break;
                }
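                if (parameter.isPrimitive()) {
                    // Class.forName("int") throws ClassNotFoundException; push int.class the way javac
                    // compiles it, by reading the wrapper's TYPE field (e.g. GETSTATIC java/lang/Integer.TYPE)
                    String wrapper = Type.getInternalName(java.lang.invoke.MethodType.methodType(parameter).wrap().returnType());
                    mv.visitFieldInsn(Opcodes.GETSTATIC, wrapper, "TYPE", Type.getDescriptor(Class.class));
                } else {
                    mv.visitLdcInsn(parameter.getName());
                    mv.visitMethodInsn(
                            Opcodes.INVOKESTATIC, Type.getInternalName(Class.class),
                            "forName",
                            "(Ljava/lang/String;)Ljava/lang/Class;",
                            false
                    );
                }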
                mv.visitInsn(Opcodes.AASTORE);
            }
        }

        mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Class.class), "getMethod", "(Ljava/lang/String;[Ljava/lang/Class;)Ljava/lang/reflect/Method;", false);
        mv.visitFieldInsn(Opcodes.PUTSTATIC, proxyClassName, fieldName, Type.getDescriptor(Method.class));
    }

    private static void generateASMMethod(Class<?> superClass, String proxyClassName, MethodVisitor mv, Method method1, int i) {
        Method method = method1;
        String asmFieldName = "_METHOD_ASM_" + method1.getName() + i;
        mv.visitLdcInsn(proxyClassName);
        mv.visitMethodInsn(Opcodes.INVOKESTATIC, Type.getInternalName(Class.class), "forName", "(Ljava/lang/String;)Ljava/lang/Class;", false);
        mv.visitLdcInsn("_asm_" + method.getName() + "_" + i);
        if (method.getParameterCount() == 0) {
            mv.visitInsn(Opcodes.ACONST_NULL);
        } else {
            switch (method.getParameterCount()) {

                case 1:
                    mv.visitInsn(Opcodes.ICONST_1);
                    break;
                case 2:
                    mv.visitInsn(Opcodes.ICONST_2);
                    break;
                case 3:
                    mv.visitInsn(Opcodes.ICONST_3);
                    break;
                default:
                    mv.visitIntInsn(Opcodes.BIPUSH, method.getParameterCount());
                    break;
            }
            mv.visitTypeInsn(Opcodes.ANEWARRAY, Type.getInternalName(Class.class));
            for (int paramIndex = 0; paramIndex < method.getParameterTypes().length; paramIndex++) {
                Class<?> parameter = method.getParameterTypes()[paramIndex];
                mv.visitInsn(Opcodes.DUP);
                switch (paramIndex) {
                    case 0:
                        mv.visitInsn(Opcodes.ICONST_0);
                        break;
                    case 1:
                        mv.visitInsn(Opcodes.ICONST_1);
                        break;
                    case 2:
                        mv.visitInsn(Opcodes.ICONST_2);
                        break;
                    case 3:
                        mv.visitInsn(Opcodes.ICONST_3);
                        break;
                    default:
                        mv.visitIntInsn(Opcodes.BIPUSH, paramIndex);
                        break;
                }
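                if (parameter.isPrimitive()) {
                    // Class.forName("int") throws ClassNotFoundException; push int.class the way javac
                    // compiles it, by reading the wrapper's TYPE field (e.g. GETSTATIC java/lang/Integer.TYPE)
                    String wrapper = Type.getInternalName(java.lang.invoke.MethodType.methodType(parameter).wrap().returnType());
                    mv.visitFieldInsn(Opcodes.GETSTATIC, wrapper, "TYPE", Type.getDescriptor(Class.class));
                } else {
                    mv.visitLdcInsn(parameter.getName());
                    mv.visitMethodInsn(
                            Opcodes.INVOKESTATIC, Type.getInternalName(Class.class),
                            "forName",
                            "(Ljava/lang/String;)Ljava/lang/Class;",
                            false
                    );
                }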
                mv.visitInsn(Opcodes.AASTORE);
            }
        }


        mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Class.class), "getMethod", "(Ljava/lang/String;[Ljava/lang/Class;)Ljava/lang/reflect/Method;", false);
        
        mv.visitFieldInsn(Opcodes.PUTSTATIC, proxyClassName, asmFieldName, Type.getDescriptor(Method.class));
    }
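The switch statements above pick an instruction for pushing small int constants (array sizes and indices). The JVM offers ICONST_M1 through ICONST_5 for -1..5, BIPUSH for values that fit in a signed byte, SIPUSH for a signed short, and LDC otherwise. In ASM, BIPUSH and SIPUSH carry an immediate operand and go through visitIntInsn(opcode, value), while visitVarInsn is reserved for local-variable instructions such as ALOAD; ICONST_n can be emitted as visitInsn(Opcodes.ICONST_0 + n) because those opcodes are consecutive, and LDC as visitLdcInsn(value). ASM's GeneratorAdapter.push(int) in the commons package performs essentially the same selection. The sketch below only models the decision as strings, so it runs without ASM; IntPushDemo and pushInstruction are hypothetical names.

```java
public class IntPushDemo {
    // Chooses the compact push instruction for an int constant:
    // ICONST_M1..ICONST_5 for -1..5, BIPUSH for a signed byte, SIPUSH for a signed short, else LDC
    static String pushInstruction(int value) {
        if (value >= -1 && value <= 5) {
            return value == -1 ? "ICONST_M1" : "ICONST_" + value;
        }
        if (value >= Byte.MIN_VALUE && value <= Byte.MAX_VALUE) {
            return "BIPUSH " + value;
        }
        if (value >= Short.MIN_VALUE && value <= Short.MAX_VALUE) {
            return "SIPUSH " + value;
        }
        return "LDC " + value;
    }

    public static void main(String[] args) {
        for (int v : new int[]{-1, 0, 5, 6, 127, 128, 40000}) {
            System.out.println(v + " -> " + pushInstruction(v));
        }
    }
}
```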

Generating the methods

private static void addSuperMethodImpl(ClassWriter cw, Class<?> superClass, String proxyClassName) {
    Method[] methods = getMethods(superClass);
    for (int i = 0; i < methods.length; i++) {
        Method method = methods[i];
        String asmMethodName = "_asm_" + method.getName() + "_" + i;
        String methodName = method.getName();
        String fieldName = "_METHOD_" + method.getName() + i;
        String asmFieldName = "_METHOD_ASM_" + methods[i].getName() + i;
        createSuperMethod(cw, superClass, method, asmMethodName);
        createProxyMethod(cw, proxyClassName, method, methodName, fieldName, asmFieldName);
    }
}
// generate the overriding (proxy) method, which calls MethodInterceptor.intercept
private static void createProxyMethod(ClassWriter cw, String proxyClassName, Method method, String methodName, String fieldName, String asmFieldName) {
    MethodVisitor mv = cw.visitMethod(Opcodes.ACC_PUBLIC, methodName, Type.getMethodDescriptor(method), null, new String[]{Type.getInternalName(Exception.class)});
    mv.visitCode();
    mv.visitVarInsn(Opcodes.ALOAD, 0);
    mv.visitFieldInsn(Opcodes.GETFIELD, proxyClassName, "methodInterceptor", Type.getDescriptor(MethodInterceptor.class));
    mv.visitVarInsn(Opcodes.ALOAD, 0);
    mv.visitFieldInsn(Opcodes.GETSTATIC, proxyClassName, fieldName, Type.getDescriptor(Method.class));
    switch (method.getParameterCount()) {
        case 0:
            mv.visitInsn(Opcodes.ICONST_0);
            break;
        case 1:
            mv.visitInsn(Opcodes.ICONST_1);
            break;
        case 2:
            mv.visitInsn(Opcodes.ICONST_2);
            break;
        case 3:
            mv.visitInsn(Opcodes.ICONST_3);
            break;
        default:
            mv.visitIntInsn(Opcodes.BIPUSH, method.getParameterCount());
            break;
    }
    mv.visitTypeInsn(Opcodes.ANEWARRAY, Type.getInternalName(Object.class));
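    // slot 0 holds this; arguments start at slot 1, and long/double occupy two slots
    int slot = 1;
    for (int paramIndex = 0; paramIndex < method.getParameterCount(); paramIndex++) {
        Class<?> paramType = method.getParameterTypes()[paramIndex];
        Type type = Type.getType(paramType);
        mv.visitInsn(Opcodes.DUP);
        switch (paramIndex) {
            case 0:
                mv.visitInsn(Opcodes.ICONST_0);
                break;
            case 1:
                mv.visitInsn(Opcodes.ICONST_1);
                break;
            case 2:
                mv.visitInsn(Opcodes.ICONST_2);
                break;
            case 3:
                mv.visitInsn(Opcodes.ICONST_3);
                break;
            default:
                mv.visitIntInsn(Opcodes.BIPUSH, paramIndex);
                break;
        }
        // load the argument with the opcode matching its type (ILOAD, LLOAD, FLOAD, DLOAD or ALOAD)
        mv.visitVarInsn(type.getOpcode(Opcodes.ILOAD), slot);
        slot += type.getSize();
        if (paramType.isPrimitive()) {
            // box primitives with Wrapper.valueOf (e.g. Integer.valueOf(int)) so they fit into Object[]
            Type boxed = Type.getType(java.lang.invoke.MethodType.methodType(paramType).wrap().returnType());
            mv.visitMethodInsn(Opcodes.INVOKESTATIC, boxed.getInternalName(), "valueOf",
                    Type.getMethodDescriptor(boxed, type), false);
        }
        mv.visitInsn(Opcodes.AASTORE);
    }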
    mv.visitFieldInsn(Opcodes.GETSTATIC, proxyClassName, asmFieldName, Type.getDescriptor(Method.class));
    mv.visitMethodInsn(Opcodes.INVOKEINTERFACE, Type.getInternalName(MethodInterceptor.class), "intercept",
            "(Ljava/lang/Object;Ljava/lang/reflect/Method;[Ljava/lang/Object;Ljava/lang/reflect/Method;)Ljava/lang/Object;", true);
    addReturnWithCheckCast(mv, method.getReturnType());
    mv.visitMaxs(2, 2);
    mv.visitEnd();
}
// generate the helper method (_asm_<name>_<i>) that calls the superclass implementation
private static void createSuperMethod(ClassWriter cw, Class<?> superClass, Method method, String methodName) {
    MethodVisitor mv = cw.visitMethod(Opcodes.ACC_PUBLIC, methodName, Type.getMethodDescriptor(method), null, new String[]{Type.getInternalName(Throwable.class)});
    mv.visitCode();
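    // push this (slot 0), then each argument with the load opcode matching its type
    // (ILOAD, LLOAD, FLOAD, DLOAD or ALOAD); long and double occupy two slots
    mv.visitVarInsn(Opcodes.ALOAD, 0);
    int slot = 1;
    for (Class<?> paramType : method.getParameterTypes()) {
        Type type = Type.getType(paramType);
        mv.visitVarInsn(type.getOpcode(Opcodes.ILOAD), slot);
        slot += type.getSize();
    }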
    mv.visitMethodInsn(Opcodes.INVOKESPECIAL, Type.getInternalName(superClass), method.getName(), Type.getMethodDescriptor(method), false);
    addReturnNoCheckCast(mv, method.getReturnType());
    mv.visitMaxs(2, 2);
    mv.visitEnd();
}

// emit the return, casting/unboxing the Object returned by intercept to the declared return type
private static void addReturnWithCheckCast(MethodVisitor mv, Class<?> returnType) {
    // compare with void.class: isAssignableFrom(Void.class) is also true for Object
    if (returnType == void.class) {
        mv.visitInsn(Opcodes.RETURN);
        return;
    }
    if (returnType.isAssignableFrom(boolean.class)) {
        mv.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Boolean.class));
        mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Boolean.class), "booleanValue", "()Z", false);
        mv.visitInsn(Opcodes.IRETURN);
    } else if (returnType.isAssignableFrom(int.class)) {
        mv.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Integer.class));
        mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Integer.class), "intValue", "()I", false);
        mv.visitInsn(Opcodes.IRETURN);
    } else if (returnType.isAssignableFrom(long.class)) {
        mv.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Long.class));
        mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Long.class), "longValue", "()J", false);
        mv.visitInsn(Opcodes.LRETURN);
    } else if (returnType.isAssignableFrom(short.class)) {
        mv.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Short.class));
        mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Short.class), "shortValue", "()S", false);
        mv.visitInsn(Opcodes.IRETURN);
    } else if (returnType.isAssignableFrom(byte.class)) {
        mv.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Byte.class));
        mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Byte.class), "byteValue", "()B", false);
        mv.visitInsn(Opcodes.IRETURN);
    } else if (returnType.isAssignableFrom(char.class)) {
        mv.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Character.class));
        mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Character.class), "charValue", "()C", false);
        mv.visitInsn(Opcodes.IRETURN);
    } else if (returnType.isAssignableFrom(float.class)) {
        mv.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Float.class));
        mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Float.class), "floatValue", "()F", false);
        mv.visitInsn(Opcodes.FRETURN);
    } else if (returnType.isAssignableFrom(double.class)) {
        mv.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(Double.class));
        mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type.getInternalName(Double.class), "doubleValue", "()D", false);
        mv.visitInsn(Opcodes.DRETURN);
    } else {
        mv.visitTypeInsn(Opcodes.CHECKCAST, Type.getInternalName(returnType));
        mv.visitInsn(Opcodes.ARETURN);
    }
}

// emit the return; the value on the stack already has the declared type
private static void addReturnNoCheckCast(MethodVisitor mv, Class<?> returnType) {
    // compare with void.class: isAssignableFrom(Void.class) is also true for Object
    if (returnType == void.class) {
        mv.visitInsn(Opcodes.RETURN);
        return;
    }
    if (returnType.isAssignableFrom(boolean.class)) {
        mv.visitInsn(Opcodes.IRETURN);
    } else if (returnType.isAssignableFrom(int.class)) {
        mv.visitInsn(Opcodes.IRETURN);
    } else if (returnType.isAssignableFrom(long.class)) {
        mv.visitInsn(Opcodes.LRETURN);
    } else if (returnType.isAssignableFrom(short.class)) {
        mv.visitInsn(Opcodes.IRETURN);
    } else if (returnType.isAssignableFrom(byte.class)) {
        mv.visitInsn(Opcodes.IRETURN);
    } else if (returnType.isAssignableFrom(char.class)) {
        mv.visitInsn(Opcodes.IRETURN);
    } else if (returnType.isAssignableFrom(float.class)) {
        mv.visitInsn(Opcodes.FRETURN);
    } else if (returnType.isAssignableFrom(double.class)) {
        mv.visitInsn(Opcodes.DRETURN);
    } else {
        mv.visitInsn(Opcodes.ARETURN);
    }
}
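The two return helpers encode a fixed table: for each primitive return type, which wrapper to cast to, which unboxing method to call, and which return instruction to use. boolean, byte, char, short and int all return with IRETURN because the JVM treats them as int on the operand stack. For a primitive class p, `x.isAssignableFrom(p)` is true only when x is p itself, so those checks are equivalent to ==. One runtime consequence: if an interceptor returns null for a method with a primitive return type, the unboxing call throws NullPointerException. The sketch below records the table and cross-checks it by reflection, without ASM; PrimitiveReturnTable, Unboxing and wrapperOf are hypothetical names.

```java
import java.lang.invoke.MethodType;
import java.util.LinkedHashMap;
import java.util.Map;

public class PrimitiveReturnTable {
    // Per primitive return type: the unboxing method called on the wrapper returned by
    // intercept(), and the return instruction that ends the generated method
    static final class Unboxing {
        final String method;
        final String returnInsn;

        Unboxing(String method, String returnInsn) {
            this.method = method;
            this.returnInsn = returnInsn;
        }
    }

    static final Map<Class<?>, Unboxing> TABLE = new LinkedHashMap<>();

    static {
        TABLE.put(boolean.class, new Unboxing("booleanValue", "IRETURN"));
        TABLE.put(byte.class, new Unboxing("byteValue", "IRETURN"));
        TABLE.put(char.class, new Unboxing("charValue", "IRETURN"));
        TABLE.put(short.class, new Unboxing("shortValue", "IRETURN"));
        TABLE.put(int.class, new Unboxing("intValue", "IRETURN"));
        TABLE.put(long.class, new Unboxing("longValue", "LRETURN"));
        TABLE.put(float.class, new Unboxing("floatValue", "FRETURN"));
        TABLE.put(double.class, new Unboxing("doubleValue", "DRETURN"));
    }

    // Wrapper class for a primitive, e.g. int -> Integer
    static Class<?> wrapperOf(Class<?> primitive) {
        return MethodType.methodType(primitive).wrap().returnType();
    }

    // Cross-checks the table: Wrapper.TYPE is the primitive class, and the unboxing
    // method exists on the wrapper and returns that primitive
    static boolean consistentWithReflection() {
        for (Map.Entry<Class<?>, Unboxing> e : TABLE.entrySet()) {
            Class<?> primitive = e.getKey();
            Class<?> wrapper = wrapperOf(primitive);
            try {
                if (wrapper.getField("TYPE").get(null) != primitive) {
                    return false;
                }
                if (wrapper.getMethod(e.getValue().method).getReturnType() != primitive) {
                    return false;
                }
            } catch (ReflectiveOperationException ex) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        TABLE.forEach((type, u) -> System.out.println(
                type + ": CHECKCAST " + wrapperOf(type).getSimpleName() + ", " + u.method + "(), " + u.returnInsn));
        System.out.println(consistentWithReflection());
    }
}
```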

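Once the bytecode array is generated, we still need a custom class loader to load the class:

  • A Map stores each class name with its bytecode
  • Before loading, call add to register the name and its bytes
  • Call loadClass to obtain the Class object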
package simple;

import java.util.HashMap;
import java.util.Map;

public class ASMClassLoader extends ClassLoader {
    private final Map<String, byte[]> classMap = new HashMap<>();

    @Override
    protected Class<?> findClass(String name) throws ClassNotFoundException {
        if (classMap.containsKey(name)) {
            byte[] bytes = classMap.get(name);
            classMap.remove(name);
            return defineClass(name, bytes, 0, bytes.length);
        }
        return super.findClass(name);
    }

    public void add(String name, byte[] bytes) {
        classMap.put(name, bytes);
    }
}

With the class generated, let's compare the result with cglib:

package simple.test;

import simple.Enhancer;

public class Main {
    public static void main(String[] args) throws Throwable {
        System.out.println("cglib dynamic proxy------------");
        net.sf.cglib.proxy.Enhancer cEnhancer = new net.sf.cglib.proxy.Enhancer();
        cEnhancer.setSuperclass(UserService.class);
        cEnhancer.setCallback(new CUserMethodInterceptor());
        UserService cUserService = (UserService) cEnhancer.create();
        System.out.println(cUserService.getClass().getName());
        System.out.println(cUserService.login("admin", "admin"));
        System.out.println(cUserService.login("admin", "admin1"));
        System.out.println("ASM dynamic proxy-----------");
        Enhancer enhancer = new Enhancer();
        enhancer.setSuperClass(UserService.class);
        enhancer.setMethodInterceptor(new UserMethodInterceptor());
        UserService service = (UserService) enhancer.create();
        System.out.println(service.getClass().getName());
        System.out.println(service.login("admin", "admin"));
        System.out.println(service.login("admin", "admin1"));
    }
}
