ASM实现findViewById

633 阅读5分钟

设计目标:

在类中定义控件成员变量,确保变量名与xml中的id一致,然后通过ASM插桩(代码插入)的方式自动生成控件查找代码。

与ButterKnife的不同之处是:

  1. 不需要每个控件变量增加注解声明id
  2. 只给需要查找控件的类增加一个注解
  3. 控件名称需要与xml中配置的id一致
  4. 类中需要增加一个initView()的空方法,并且在流程中调用(例如在setContentView()之后调用)
  5. 如果控件名称与id不一致,需要自己初始化

findViewById的逻辑:

View.findViewById:

public final <T extends View> T findViewById(@IdRes int id) {
    // A NO_ID lookup short-circuits to null; any other id is searched by the
    // (possibly overridden) traversal, which ViewGroup extends to its children.
    return id == NO_ID ? null : findViewTraversal(id);
}

protected <T extends View> T findViewTraversal(@IdRes int id) {
    // Base View: the only candidate is this view itself.
    return id == mID ? (T) this : null;
}

ViewGroup.findViewTraversal:

protected <T extends View> T findViewTraversal(@IdRes int id) {
    // The group itself is checked before any child.
    if (id == mID) {
        return (T) this;
    }

    final View[] children = mChildren;
    for (int index = 0, count = mChildrenCount; index < count; index++) {
        final View child = children[index];
        // Children flagged PFLAG_IS_ROOT_NAMESPACE are not searched.
        final boolean startsOwnNamespace = (child.mPrivateFlags & PFLAG_IS_ROOT_NAMESPACE) != 0;
        if (startsOwnNamespace) {
            continue;
        }
        // Depth-first: recurse into the child and return the first hit.
        final View match = child.findViewById(id);
        if (match != null) {
            return (T) match;
        }
    }
    return null;
}

R.class:

aapt编译资源后会生成R$anim.class、R$attr.class、R$id.class等R类的静态内部类……

插件工作流程

  • 扫描所有的类,寻找R$id.class,生成id与值的映射表
  • 扫描带注解的类,生成成员变量与控件id的映射,并进行代码插入
/**
 * Gradle Transform that works in two passes over the project's classes:
 * first it scans every input to collect R$xxx constants and the application id,
 * then it injects findViewById code into the inputs and copies them to the
 * transform output. Not incremental: all outputs are deleted on each run.
 */
class BindTransform extends Transform {
    Project project
    CodeScanTool codeScan = new CodeScanTool()
    CodeInsertTool codeInsert = new CodeInsertTool()

    BindTransform(Project project) {
        super()
        this.project = project
    }

    @Override
    String getName() { "bind-view" }

    @Override
    Set<QualifiedContent.ContentType> getInputTypes() { TransformManager.CONTENT_CLASS }

    @Override
    Set<? super QualifiedContent.Scope> getScopes() { TransformManager.PROJECT_ONLY }

    @Override
    boolean isIncremental() { false }

    @Override
    void transform(TransformInvocation transformInvocation) throws TransformException, InterruptedException, IOException {
        transformInvocation.outputProvider.deleteAll()
        codeScan.init()
        def context = transformInvocation.context
        println(context.projectName + context.variantName)

        // Pass 1: scan jars and directories to build the R table and find the app id.
        for (TransformInput input in transformInvocation.inputs) {
            for (JarInput jarInput in input.jarInputs) {
                codeScan.scanJar(jarInput)
            }
            for (directoryInput in input.directoryInputs) {
                eachRegularFile(directoryInput) { File f -> codeScan.scanClass(f) }
            }
        }

        // The insertion pass reads the table collected above.
        codeInsert.mPackageRInfo = codeScan.mPackageRinfo

        // Pass 2: rewrite each input in place, then copy it to the transform output.
        for (TransformInput input in transformInvocation.inputs) {
            for (JarInput jarInput in input.jarInputs) {
                codeInsert.insertCodeToJar(jarInput)
                codeInsert.copyJarInput(jarInput, transformInvocation)
            }
            for (directoryInput in input.directoryInputs) {
                eachRegularFile(directoryInput) { File f -> codeInsert.insertCodeToClass(f) }
                codeInsert.copyDirectoryInput(directoryInput, transformInvocation)
            }
        }
    }

    /** Calls {@code action} with every regular file found, at any depth, under the directory input. */
    private static void eachRegularFile(def directoryInput, Closure action) {
        directoryInput.file.eachFileRecurse { File candidate ->
            if (candidate.isFile()) {
                action(candidate)
            }
        }
    }
}

扫描的实现:

/**
 * Scan pass of the plugin. It collects two things into {@link #mPackageRinfo}:
 * the int constants declared in R$xxx inner classes (e.g. R$id), and the
 * APPLICATION_ID constant of BuildConfig.
 */
class CodeScanTool {
    /**
     * Matches paths such as "com/example/R$id.class".
     * Group 1 is the slash-separated package path; group 2 is the R inner-class name.
     */
    private static final Pattern R_CLASS_PATTERN = Pattern.compile("(.+)/R\\\$(.+)\\.class")

    PackageRInfo mPackageRinfo = new PackageRInfo()

    void init() {
        mPackageRinfo.init()
    }

    void print() {
        mPackageRinfo.print()
    }

    void scanJar(JarInput jarInput) {
        scanJar(jarInput.file)
    }

    /**
     * Scans every entry of a jar, skipping support-library and AndroidX classes.
     *
     * @param jarFile jar to scan; may be null
     * @return false when {@code jarFile} is null, true after the jar was scanned
     */
    boolean scanJar(File jarFile) {
        if (!jarFile)
            return false

        JarFile file = new JarFile(jarFile)
        try {
            Enumeration<JarEntry> enumeration = file.entries()
            while (enumeration.hasMoreElements()) {
                JarEntry jarEntry = enumeration.nextElement()
                String entryName = jarEntry.getName()
                // Support/AndroidX classes are not scanned. Jar entry names use '/' separators,
                // and `continue` (not `break`) keeps scanning the remaining entries of the jar.
                if (entryName.startsWith("android/support") || entryName.startsWith("androidx/"))
                    continue
                InputStream inputStream = file.getInputStream(jarEntry)
                try {
                    scanClass(inputStream, entryName)
                } finally {
                    inputStream.close()
                }
            }
        } finally {
            file.close()
        }
        return true
    }

    /**
     * Scans one class file from a directory input.
     * NOTE(review): only the bare file name is passed, so R$xxx.class files found in
     * directories never match R_CLASS_PATTERN (it needs a '/'); R tables are
     * therefore only picked up from jars. BuildConfig detection relies on the bare name.
     */
    void scanClass(File file) {
        file.withInputStream { InputStream is ->
            scanClass(is, file.name)
        }
    }

    /**
     * Reads R$xxx constants or the BuildConfig application id from one class.
     *
     * @param inputStream class bytes; not closed here
     * @param entryName   jar entry path, or bare file name for directory inputs
     */
    void scanClass(InputStream inputStream, String entryName) {
        Matcher matcher = R_CLASS_PATTERN.matcher(entryName)
        println("scanclass name:" + entryName)
        if (matcher.matches()) {
            // R inner class such as R$id: record each int field. Only fields are needed,
            // so method bodies are skipped and no ClassWriter is built.
            ClassReader cr = new ClassReader(inputStream)
            ScanRinfoClassVisitor cv = new ScanRinfoClassVisitor(Opcodes.ASM5, null, matcher.group(1), matcher.group(2))
            cr.accept(cv, ClassReader.SKIP_CODE)
        } else if ("BuildConfig.class" == entryName) {
            // Application package name, used as the default lookup package.
            ClassReader cr = new ClassReader(inputStream)
            ScanAppIdClassVisitor cv = new ScanAppIdClassVisitor(Opcodes.ASM5, null)
            cr.accept(cv, ClassReader.SKIP_CODE)
        }
    }

    /** Records every int field of an R$xxx class as (pkg, type, name) -> value. */
    private class ScanRinfoClassVisitor extends ClassVisitor {
        // Slash-separated package path from the entry name, e.g. "com/example".
        // NOTE(review): APPLICATION_ID is dot-separated; presumably PackageRInfo normalizes — verify.
        String pkg;
        // R inner-class name, e.g. "id" or "layout".
        String type;

        ScanRinfoClassVisitor(int api, ClassVisitor classVisitor, String pkg, String type) {
            super(api, classVisitor)
            this.pkg = pkg
            this.type = type
        }

        @Override
        FieldVisitor visitField(int access, String name, String descriptor, String signature, Object value) {
            // Non-final R fields (e.g. in library R classes) carry no constant value; skip them
            // instead of recording null.
            if (descriptor == 'I' && value instanceof Integer) {
                mPackageRinfo.putValue(pkg, type, name, (Integer) value)
            }
            return super.visitField(access, name, descriptor, signature, value)
        }
    }

    /** Captures BuildConfig.APPLICATION_ID into mPackageRinfo.mAppId. */
    private class ScanAppIdClassVisitor extends ClassVisitor {
        ScanAppIdClassVisitor(int api, ClassVisitor classVisitor) {
            super(api, classVisitor)
        }

        @Override
        FieldVisitor visitField(int access, String name, String descriptor, String signature, Object value) {
            // Libraries' LIBRARY_PACKAGE_NAME is not supported.
            if (name == "APPLICATION_ID" && value instanceof String) {
                mPackageRinfo.mAppId = value
            }
            return super.visitField(access, name, descriptor, signature, value)
        }
    }
}
  1. 扫描R.java的静态内部类,如R$id.class、R$attr.class等,生成id映射表
  2. 寻找应用包名,最终需要根据控件所在的包,来寻找对应的id

代码植入的实现:

/**
 * Insertion pass of the plugin. It rewrites class files, either loose or inside jars,
 * through BindViewClassVisitor, then copies the inputs to the transform output.
 */
class CodeInsertTool {
    PackageRInfo mPackageRInfo;

    /**
     * Rewrites a single .class file in place; other files are ignored.
     * The result is written to "<name>.opt" first and then renamed over the original.
     */
    void insertCodeToClass(File file) {
        if (!file.name.endsWith(".class")) {
            return
        }
        println("insertCodeToClass  name:" + file.name + "  path:" + file.path)

        def optClass = new File(file.getParent(), file.name + ".opt")
        if (optClass.exists()) {
            optClass.delete()
        }
        // Streams are closed even if ASM throws.
        byte[] bytes = file.withInputStream { InputStream is -> generateCode(is, file.name) }
        optClass.withOutputStream { OutputStream os -> os.write(bytes) }
        if (file.exists()) {
            file.delete()
        }
        if (!optClass.renameTo(file)) {
            throw new IOException("Failed to replace " + file.path)
        }
    }

    /**
     * Rewrites every class entry of the jar in place. Directory entries are recreated
     * empty, and non-class entries (manifest, resources, META-INF) are copied unchanged.
     */
    void insertCodeToJar(JarInput jarInput) {
        def jarFile = jarInput.file
        def optJar = new File(jarFile.getParent(), jarFile.name + ".opt")
        if (optJar.exists())
            optJar.delete()

        JarFile file = new JarFile(jarFile)
        JarOutputStream jarOutputStream = null
        try {
            jarOutputStream = new JarOutputStream(new FileOutputStream(optJar))
            Enumeration<JarEntry> enumeration = file.entries()
            while (enumeration.hasMoreElements()) {
                JarEntry jarEntry = enumeration.nextElement()
                String entryName = jarEntry.getName()
                jarOutputStream.putNextEntry(new ZipEntry(entryName))
                if (!jarEntry.isDirectory()) {
                    InputStream inputStream = file.getInputStream(jarEntry)
                    try {
                        // Only real class files can be parsed by ClassReader.
                        byte[] bytes = isTransformableClass(entryName) ?
                                generateCode(inputStream, entryName) : inputStream.bytes
                        jarOutputStream.write(bytes)
                    } finally {
                        inputStream.close()
                    }
                }
                jarOutputStream.closeEntry()
            }
        } finally {
            if (jarOutputStream != null) {
                jarOutputStream.close()
            }
            file.close()
        }

        if (jarFile.exists()) {
            jarFile.delete()
        }
        if (!optJar.renameTo(jarFile)) {
            throw new IOException("Failed to replace " + jarFile.path)
        }
    }

    /** True for .class entries outside META-INF (which may hold versioned/module classes). */
    private static boolean isTransformableClass(String entryName) {
        return entryName.endsWith(".class") && !entryName.startsWith("META-INF/")
    }

    /**
     * Runs one class through BindViewClassVisitor and returns the rewritten bytes.
     *
     * @param is   class bytes; not closed here
     * @param name name used only for logging
     */
    byte[] generateCode(InputStream is, String name) {
        println("generateCode name:" + name)
        ClassReader cr = new ClassReader(is)
        // COMPUTE_MAXS: the injected findViewById sequence needs up to 3 operand-stack
        // slots, more than an empty initView() declares. Without it, the class fails verification.
        ClassWriter cw = new ClassWriter(cr, ClassWriter.COMPUTE_MAXS)

        BindViewClassVisitor cv = new BindViewClassVisitor(Opcodes.ASM5, cw)
        cv.mPackageRInfo = mPackageRInfo
        cr.accept(cv, ClassReader.EXPAND_FRAMES)
        return cw.toByteArray()
    }

    /** Copies the (already rewritten) jar to its transform output location. */
    void copyJarInput(JarInput jarInput, TransformInvocation transformInvocation) {
        File dest = getDestFile(jarInput, transformInvocation.outputProvider)
        FileUtils.copyFile(jarInput.file, dest)
    }

    /** Output location for a jar. It is suffixed with the MD5 of its path so same-named jars don't collide. */
    File getDestFile(JarInput jarInput, TransformOutputProvider outputProvider) {
        def destName = jarInput.name
        def hexName = DigestUtils.md5Hex(jarInput.file.absolutePath)
        if (destName.endsWith(".jar")) {
            destName = destName.substring(0, destName.length() - 4)
        }
        File dest = outputProvider.getContentLocation(destName + "_" + hexName, jarInput.contentTypes, jarInput.scopes, Format.JAR)
        return dest
    }

    /** Copies the (already rewritten) directory to its transform output location. */
    void copyDirectoryInput(DirectoryInput directoryInput, TransformInvocation transformInvocation) {
        File dest = transformInvocation.outputProvider.getContentLocation(directoryInput.name,
                directoryInput.contentTypes, directoryInput.scopes, Format.DIRECTORY)
        FileUtils.copyDirectory(directoryInput.file, dest)
    }
}
  1. 扫描类是否有@BindView的注解,如果没有,说明无需代码植入
  2. 扫描该类中的成员变量,并且判断该变量名是否在id映射表中
  3. 如果存在,说明是一个控件,则记录到map中
  4. 扫描方法,在initView()中进行代码植入,类似于view = findViewById(id)
/**
 * Injects findViewById calls into initView() of classes annotated with
 * @com.example.annotation.BindView.
 *
 * Fields whose names exist in the R id table (looked up through mPackageRInfo under
 * pkgname) are recorded in visitField. visitMethod then wraps initView() so that every
 * recorded field is assigned at the very start of the method.
 */
class BindViewClassVisitor extends ClassVisitor {
    boolean hasBindView = false
    // Set when the class carries @kotlin.Metadata; not otherwise used by this visitor.
    boolean isKotlin = false
    String pkgname = null
    PackageRInfo mPackageRInfo
    // Maps field name to its id/type. LinkedHashMap keeps field-declaration order, so
    // the injected statements (and therefore the build output) are deterministic.
    Map<String, FieldInfo> map = new LinkedHashMap<>()
    String className = null

    /** Resolved id and type descriptor of one bindable field. */
    class FieldInfo {
        int id
        String desc
        // Internal class name for CHECKCAST, e.g. "Landroid/widget/TextView;" -> "android/widget/TextView".
        String clz

        FieldInfo(int id, String desc) {
            this.id = id
            this.desc = desc
            clz = desc.substring(1, desc.length() - 1)
        }
    }

    /** Uses the annotation's pkgname, falling back to the application id when it is null or empty. */
    void setPkg(String pkg) {
        if (pkg == null || pkg.length() == 0) {
            pkgname = mPackageRInfo.mAppId
        } else {
            pkgname = pkg
        }
    }

    BindViewClassVisitor(int api, ClassVisitor classVisitor) {
        super(api, classVisitor)
    }

    @Override
    void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
        super.visit(version, access, name, signature, superName, interfaces)
        className = name
    }

    @Override
    AnnotationVisitor visitAnnotation(String descriptor, boolean visible) {
        if (descriptor == "Lcom/example/annotation/BindView;") {
            return new BindViewAnnotationVisitor(Opcodes.ASM6, super.visitAnnotation(descriptor, visible))
        } else if (descriptor == "Lkotlin/Metadata;") {
            isKotlin = true
            return super.visitAnnotation(descriptor, visible)
        } else {
            return super.visitAnnotation(descriptor, visible)
        }
    }

    /** Records instance fields of reference type whose name is a known R id. */
    @Override
    FieldVisitor visitField(int access, String name, String descriptor, String signature, Object value) {
        // Only instance fields of an object type can take CHECKCAST + PUTFIELD. Primitive/array
        // descriptors would also break FieldInfo's substring.
        boolean bindable = (access & Opcodes.ACC_STATIC) == 0 && descriptor.startsWith("L")
        if (hasBindView && bindable && mPackageRInfo.hasValue(pkgname, "id", name)) {
            map.put(name, new FieldInfo(mPackageRInfo.getValue(pkgname, "id", name), descriptor))
        }
        return super.visitField(access, name, descriptor, signature, value)
    }

    /**
     * Wraps non-static initView() methods. With no parameters, lookups go through
     * this.findViewById. Otherwise they go through the first parameter, which must be an object.
     */
    @Override
    MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
        MethodVisitor next = super.visitMethod(access, name, descriptor, signature, exceptions)
        // A static initView() has no `this` in slot 0 to assign fields on.
        if (hasBindView && name == "initView" && (access & Opcodes.ACC_STATIC) == 0) {
            boolean noparam = descriptor.startsWith("()")
            String receiverOwner = noparam ? className : firstArgInternalName(descriptor)
            if (receiverOwner != null) {
                return new InitViewMethodVisitor(Opcodes.ASM6, noparam, next, receiverOwner)
            }
        }
        return next
    }

    /** Internal name of the first parameter if it is an object type ("(Landroid/view/View;)V" -> "android/view/View"), else null. */
    private static String firstArgInternalName(String descriptor) {
        if (descriptor.startsWith("(L")) {
            int end = descriptor.indexOf(';')
            if (end > 2) {
                return descriptor.substring(2, end)
            }
        }
        return null
    }

    /** Marks the class as bindable and resolves the lookup package from the annotation. */
    class BindViewAnnotationVisitor extends AnnotationVisitor {

        BindViewAnnotationVisitor(int api, AnnotationVisitor annotationVisitor) {
            super(api, annotationVisitor)
        }

        @Override
        void visitEnd() {
            super.visitEnd()
            // No pkgname element was seen: fall back to the application id.
            if (!hasBindView) {
                hasBindView = true
                setPkg(null)
            }
        }

        @Override
        void visit(String name, Object value) {
            super.visit(name, value)
            if (name == "pkgname") {
                hasBindView = true
                setPkg(value)
            }
        }
    }

    /** Prepends `field = (Type) receiver.findViewById(id)` for every recorded field. */
    class InitViewMethodVisitor extends MethodVisitor {
        // Operand-stack slots used by the injected sequence: PUTFIELD target, receiver, id.
        private static final int INJECTED_MAX_STACK = 3

        boolean noparam = false
        // Class declaring the findViewById invoked on the receiver.
        String receiverOwner

        InitViewMethodVisitor(int api, boolean noparam, MethodVisitor methodVisitor) {
            super(api, methodVisitor)
            this.noparam = noparam
            this.receiverOwner = noparam ? className : "android/view/View"
        }

        InitViewMethodVisitor(int api, boolean noparam, MethodVisitor methodVisitor, String receiverOwner) {
            super(api, methodVisitor)
            this.noparam = noparam
            this.receiverOwner = receiverOwner
        }

        @Override
        void visitCode() {
            super.visitCode()
            final MethodVisitor out = mv
            final String ownerClass = className
            final String receiver = receiverOwner
            // Slot 0 is `this`; slot 1 is the first parameter of initView(param).
            final int receiverSlot = noparam ? 0 : 1
            map.each { String fieldName, FieldInfo info ->
                out.visitVarInsn(Opcodes.ALOAD, 0)
                out.visitVarInsn(Opcodes.ALOAD, receiverSlot)
                out.visitLdcInsn(Integer.valueOf(info.id))
                out.visitMethodInsn(Opcodes.INVOKEVIRTUAL, receiver, "findViewById", "(I)Landroid/view/View;", false)
                out.visitTypeInsn(Opcodes.CHECKCAST, info.clz)
                out.visitFieldInsn(Opcodes.PUTFIELD, ownerClass, fieldName, info.desc)
            }
        }

        @Override
        void visitMaxs(int maxStack, int maxLocals) {
            // An empty initView() declares a smaller max stack than the injected code needs.
            // Raise it so the class also verifies when the ClassWriter does not compute maxs.
            int stack = map.isEmpty() ? maxStack : Math.max(maxStack, INJECTED_MAX_STACK)
            super.visitMaxs(stack, maxLocals)
        }
    }
}

举例:

// Sample @BindView Activity. Per the generated listing, the plugin fills `text` and `btn`
// in initView(); `btm` is assigned by hand because its name differs from its id (bottom_text).
@BindView
public class Main extends Activity {
    TextView text;
    Button btn;
    TextView btm;

    @Override
    protected void onCreate(@Nullable Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        this.setContentView(R.layout.main_test);
        // Bind views only after the layout has been set.
        this.initView();
    }

    @Override
    protected void onResume() {
        super.onResume();
        this.text.setText("what a happy boy");
    }

    // Hook for the plugin; only fields whose names do not match an id are assigned here.
    void initView() {
        this.btm = this.findViewById(R.id.bottom_text);
    }
}

生成的类:——在initView()中进行了控件的初始化

// Decompiled class after the transform. R.* constants appear as inlined int literals.
@BindView
public class Main extends Activity {
    TextView text;
    Button btn;
    TextView btm;

    // Implicit default constructor, made visible by the decompiler.
    public Main() {
    }

    protected void onCreate(@Nullable Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        this.setContentView(2131361822);
        this.initView();
    }

    protected void onResume() {
        super.onResume();
        this.text.setText("what a happy boy");
    }

    void initView() {
        // Injected at method start for fields whose names match R.id entries.
        this.text = (TextView)this.findViewById(2131165351);
        this.btn = (Button)this.findViewById(2131165251);
        // The original hand-written assignment (btm's name does not match its id).
        this.btm = (TextView)this.findViewById(2131165250);
    }
}