爆肝源码系列——Java序列化解析 (一)

687 阅读7分钟

“这是我参与8月更文挑战的第2天,活动详情查看: 8月更文挑战” juejin.cn/post/698796…

Java如何得到 suid 的?

了解过序列化的同学肯定知道,suid 在反序列化时侯有两个,一个是本地类计算出来的 suid,一个是从序列化字节流中读取的 suid,我们的分析也是针对这两种进行展开。

1. 从异常抛出开始

为了寻找源码分析的思路,我们可以通过引发异常开始

1.1 创建定义一个可序列化的类

public class User implements Serializable {
    public String name;
    private Integer age;

    // getter & setter & toString
    ...
}

1.2 运行序列化方法

static void test1() throws IOException {
    User user = new User("player", 21);
    user.setName("player");
    user.setAge(21);
    System.out.println(user);
    FileSerializeUtil.serialize(user, "user.obj");
}
public static void serialize(Object object, String filepath) throws IOException {
    try (ObjectOutputStream objectOutputStream = new ObjectOutputStream(new FileOutputStream(filepath))) {
        objectOutputStream.writeObject(object);
    } catch (IOException e) {
        logger.warning("serialize: fail to serialize object to " + filepath);
        throw e;
    }
}

1.3 做点手脚

添加一个成员变量到 User

public class User implements Serializable {
    public String name;
    private Integer age;
    private String address;

    ...
}

1.4 运行反序列化方法

static void test2() throws IOException, ClassNotFoundException {
    User user = (User) FileSerializeUtil.deserialize("user.obj");
    System.out.println(user);
}
public static Object deserialize(String filepath) throws IOException, ClassNotFoundException {
    try (ObjectInputStream objectInputStream = new ObjectInputStream(new FileInputStream(filepath))) {
        return objectInputStream.readObject();
    } catch (IOException | ClassNotFoundException e) {
        logger.warning("deserialize: file to deserialize object from " + filepath);
        throw e;
    }
}

1.5 触发异常,查看异常信息

Exception in thread "main" java.io.InvalidClassException: serial.User; local class incompatible: stream classdesc serialVersionUID = 8985745470054656491, local class serialVersionUID = -4967160969146043535
	at java.base/java.io.ObjectStreamClass.initNonProxy(ObjectStreamClass.java:715)
	...

由此我们可以从 715 行开始我们的故事

2. 主干源码分析

2.1 initNonProxy()分析

为了确定本地的 serialVersionUID 和来自序列化字节流的 serialVersionUID 是如何获得的,在 ObjectStreamClass 类中我们找到了抛出异常的地方,在 initNonProxy() 有这样一段判断代码,

void initNonProxy(ObjectStreamClass model,
                      Class<?> cl,
                      ClassNotFoundException resolveEx,
                      ObjectStreamClass superDesc)
        throws InvalidClassException
    {
    	// (1) suid 的值在这里获得的
        long suid = Long.valueOf(model.getSerialVersionUID());
    	...
        if (cl != null) {
            ...
            // 异常在这里抛出
            if (model.serializable == osc.serializable &&
                    !cl.isArray() && !isRecord(cl) &&
                	// (2) osc 也是通过 getSerialVersionUID() 得到序列化
                    suid != osc.getSerialVersionUID()) {
                throw new InvalidClassException(osc.name,
                        "local class incompatible: " +
                                "stream classdesc serialVersionUID = " + suid +
                                ", local class serialVersionUID = " +
                                osc.getSerialVersionUID());
            }
            ...
        }
}

异常产生的原因在于 suidosc.getSerialVersionUID() 不相等。而我们可以知道,两者本质上都调用了 getSerialVersionUID() 方法,我们断定该方法中一定有某种判断。下面我们对 getSerialVersionUID() 方法展开分析。

2.2 getSerialVersionUID() 分析

首先我们注意到方法的文档注释第一行,

Return the serialVersionUID for this class.

很显然,告诉我们这个方法返回的类的 serialVersionUID 。我们在源码中 debug 也可以得到 computeDefaultSUID() 返回的即为本地类的 suid

该方法的源代码如下,

public long getSerialVersionUID() {
    // REMIND: synchronize instead of relying on volatile?
    if (suid == null) {
        if (isRecord)
            return 0L;

        suid = AccessController.doPrivileged(
            new PrivilegedAction<Long>() {
                public Long run() {
                    return computeDefaultSUID(cl);
                }
            }
        );
    }
    return suid.longValue();
}

为了更加清楚发现两者调用的不同,我们 debug 看看,

对于(1)处 suid

对于(2)处 osc.getSerialVersionUID() 则有两种可能,

  • User 类中没有自定义 serialVersionUID()
  • User 类中自定义了 serialVersionUID()

我们很容易理解就是本地序列化类的 UID反序列化字节流 UID 都是通过这个方法来获得的。对于没有自定义的本地的 UID将会走 if 中的 computeDefaultSUID() 方法进行计算类的 SUID,对于自定义 SUID 和字节流中的 SUID,将会直接返回。

关于 computeDefaultSUID() 的详解可以阅读下文,现在我们先继续讲主线内容

那么我们会有疑问,既然不走 computeDefaultSUID(),那么自定义 SUID 和字节流中的 SUID 是从哪得到的呢?我们通过 IDEA 强大的检索功能,我们可以看看引用了 suid 的地方,

  • 蓝色标注的是我们没有自定义 suid 类的计算路径
  • 目测红色 getDeclaredSUID() 是获取自定义(声明)的 suid
  • 红色 in.readLong() 是从文件输入流中获取的

下面我们分析获取自定义声明的 suid 来印证我们的猜测。

in.readLong() 部分我们不做分析,因为继续进入将会进入 I/O 部分的 Java 类,这里简单理解就是通过输入流从文件中读取一个 long 类型整数。

2.3 获取自定义(声明的)suid

2.3.1 ObjectStreamClass() 分析

通过 IDEA 跳转,我们来到了调用 getDeclaredSUID() 的地方,看了看这个方法的注释,

Creates local class descriptor representing given class.

大概的意思是,创建一个对应类的描述符。其他的细节我们不需要过分关注,我们只需要关注 if ... else 条件,

...
// 判断是否实现 Serializable 接口    
serializable = Serializable.class.isAssignableFrom(cl);
...
if (serializable) {
    AccessController.doPrivileged(new PrivilegedAction<>() {
        public Void run() {
            ...
            suid = getDeclaredSUID(cl);
            ...
        }
    });
} else {
    // 如果没有实现 suid = 0
    suid = Long.valueOf(0);
    ...
}

不难看出,对于没有实现 Serializable 的类来说,默认分配的 suid 是 0,对于实现了的类来说,走 getDeclaredSUID() 得到 suid

2.3.2 getDeclaredSUID() 分析
private static Long getDeclaredSUID(Class<?> cl) {
    try {
        // 通过反射得到成员变量 serialVersionUID
        Field f = cl.getDeclaredField("serialVersionUID");
        // 检查成员变量 serialVersionId 是否是静态常量
        int mask = Modifier.STATIC | Modifier.FINAL;
        if ((f.getModifiers() & mask) == mask) {
            f.setAccessible(true);
            // 真正得到本地类自定义 suid 的地方
            return Long.valueOf(f.getLong(null));
        }
    } catch (Exception ex) {
    }
    return null;
}

这个方法代码并不多,首先通过反射得到成员变量 serialVersionUID,判断其是否是静态常量,然后通过反射得到它的值,就这么简单。看吧,获取自定义(声明)的 suid 已经分析完毕了。

3. 默认计算类 suid 的方法 computeDefaultSUID()

这个方法是在类没有自定义 serialVersionUID 的情况下用于根据类的“信息”来计算 suid 的方法

private static long computeDefaultSUID(Class<?> cl) {
    // 检查是否是代理对象和是否是 Serializable 的子类或子接口
    if (!Serializable.class.isAssignableFrom(cl) || Proxy.isProxyClass(cl))
    {
        return 0L;
    }

    try {
        // 将输出缓冲到字节数组中
        ByteArrayOutputStream bout = new ByteArrayOutputStream();
        // 这里我们可以理解为 dout 将输出缓存到 bout
        DataOutputStream dout = new DataOutputStream(bout);

        // 类名写入字节数组
        dout.writeUTF(cl.getName());

        // 获取类名修饰符
        int classMods = cl.getModifiers() &
            (Modifier.PUBLIC | Modifier.FINAL |
             Modifier.INTERFACE | Modifier.ABSTRACT);

        // 获取类成员方法
        Method[] methods = cl.getDeclaredMethods();
        
        if ((classMods & Modifier.INTERFACE) != 0) {
            // 如果类是接口
            classMods = (methods.length > 0) ?
                // 如果存在方法就与抽象修饰符进行或运算
                (classMods | Modifier.ABSTRACT) :
            	// 不存在方法就与抽象修饰符的反码进行与运算
                (classMods & ~Modifier.ABSTRACT);
        }
        // 类修饰符写入字节数组
        dout.writeInt(classMods);

        if (!cl.isArray()) {
            // 补偿对于数组的处理,对于数组类型将会得到 Cloneable 和 Serializable,所以数组不必走进来
            Class<?>[] interfaces = cl.getInterfaces();
            String[] ifaceNames = new String[interfaces.length];
            for (int i = 0; i < interfaces.length; i++) {
                ifaceNames[i] = interfaces[i].getName();
            }
            // 对接口名进行排序,避免对于同样的接口数组产生不同的写入
            Arrays.sort(ifaceNames);
            for (int i = 0; i < ifaceNames.length; i++) {
                // 接口名写入字节数组
                dout.writeUTF(ifaceNames[i]);
            }
        }

        // 获取成员变量
        Field[] fields = cl.getDeclaredFields();
        MemberSignature[] fieldSigs = new MemberSignature[fields.length];
        // 获取成员变量的签名
        for (int i = 0; i < fields.length; i++) {
            fieldSigs[i] = new MemberSignature(fields[i]);
        }
        Arrays.sort(fieldSigs, new Comparator<>() {
            public int compare(MemberSignature ms1, MemberSignature ms2) {
                // 按照成员变量名进行排序
                return ms1.name.compareTo(ms2.name);
            }
        });
        
        // 对成员变量进行处理
        for (int i = 0; i < fieldSigs.length; i++) {
            MemberSignature sig = fieldSigs[i];
            // 得到成员变量修饰符
            int mods = sig.member.getModifiers() &
                (Modifier.PUBLIC | Modifier.PRIVATE | Modifier.PROTECTED |
                 Modifier.STATIC | Modifier.FINAL | Modifier.VOLATILE |
                 Modifier.TRANSIENT);
            // 非私有则进行写入
            if (((mods & Modifier.PRIVATE) == 0) ||
                // 如果是 static 或 transient 则写入
                ((mods & (Modifier.STATIC | Modifier.TRANSIENT)) == 0))
                /*
                 * 也就是说,在这里如果是非私有则一定写入,如果是私有但满足 static 或 transient 才写入
                 */
            {
                // 签名信息写入字节数组
                dout.writeUTF(sig.name);
                dout.writeInt(mods);
                dout.writeUTF(sig.signature);
            }
        }

        // 静态类则写入
        if (hasStaticInitializer(cl)) {
            dout.writeUTF("<clinit>");
            dout.writeInt(Modifier.STATIC);
            dout.writeUTF("()V");
        }

        // 获取构造方法
        Constructor<?>[] cons = cl.getDeclaredConstructors();
        MemberSignature[] consSigs = new MemberSignature[cons.length];
        // 获取构造方法签名
        for (int i = 0; i < cons.length; i++) {
            consSigs[i] = new MemberSignature(cons[i]);
        }
        // 对构造方法签名进行排序
        Arrays.sort(consSigs, new Comparator<>() {
            public int compare(MemberSignature ms1, MemberSignature ms2) {
                return ms1.signature.compareTo(ms2.signature);
            }
        });
        for (int i = 0; i < consSigs.length; i++) {
            MemberSignature sig = consSigs[i];
            // 获取构造方法修饰符
            int mods = sig.member.getModifiers() &
                (Modifier.PUBLIC | Modifier.PRIVATE | Modifier.PROTECTED |
                 Modifier.STATIC | Modifier.FINAL |
                 Modifier.SYNCHRONIZED | Modifier.NATIVE |
                 Modifier.ABSTRACT | Modifier.STRICT);
            // 如果是非私有构造方法则进行写入
            if ((mods & Modifier.PRIVATE) == 0) {
                dout.writeUTF("<init>");
                dout.writeInt(mods);
                dout.writeUTF(sig.signature.replace('/', '.'));
            }
        }

        MemberSignature[] methSigs = new MemberSignature[methods.length];
        // 获取方法的签名
        for (int i = 0; i < methods.length; i++) {
            methSigs[i] = new MemberSignature(methods[i]);
        }
        // 方法签名排序
        Arrays.sort(methSigs, new Comparator<>() {
            public int compare(MemberSignature ms1, MemberSignature ms2) {
                int comp = ms1.name.compareTo(ms2.name);
                if (comp == 0) {
                    comp = ms1.signature.compareTo(ms2.signature);
                }
                return comp;
            }
        });
        for (int i = 0; i < methSigs.length; i++) {
            MemberSignature sig = methSigs[i];
            // 获取方法修饰符
            int mods = sig.member.getModifiers() &
                (Modifier.PUBLIC | Modifier.PRIVATE | Modifier.PROTECTED |
                 Modifier.STATIC | Modifier.FINAL |
                 Modifier.SYNCHRONIZED | Modifier.NATIVE |
                 Modifier.ABSTRACT | Modifier.STRICT);
            // 非私有则写入
            if ((mods & Modifier.PRIVATE) == 0) {
                dout.writeUTF(sig.name);
                dout.writeInt(mods);
                dout.writeUTF(sig.signature.replace('/', '.'));
            }
        }

        // 刷新,将结果保存到字节数组
        dout.flush();

        // 对之前字节数组进行 SHA 运算
        MessageDigest md = MessageDigest.getInstance("SHA");
        byte[] hashBytes = md.digest(bout.toByteArray());
        long hash = 0;
        for (int i = Math.min(hashBytes.length, 8) - 1; i >= 0; i--){
            hash = (hash << 8) | (hashBytes[i] & 0xFF);
        }
        return hash;
    } catch (IOException ex) {
        throw new InternalError(ex);
    } catch (NoSuchAlgorithmException ex) {
        throw new SecurityException(ex.getMessage());
    }
}

这个方法看上去很长,其实做的事情是重复且容易理解,简单来说就是通过反射获取这个类的各种信息,将它们放到一个字节数组中,然后使用 hash 函数(SHA)进行运算得到一个代表类的“摘要”。这里作者已经充分注释了整个方法,读者可以通过粗略阅读来理解这个方法的具体细节。

hash 函数是一种常用于加密或者生成信息摘要的方法,其特点主要有

  • 任意输入,固定输出
  • 防碰撞,也叫做差之毫厘居距之千里,输入哪怕修改了 1 个位,计算出来结果也会发生大变化
  • 单向性,其实也是任意输入,固定输出所决定的,通过 hash 值反向计算出原来的输入计算上是不可行的

如果读者对这部分感兴趣,待作者开“区块链”专栏为读者一一介绍“加密的那些事情”

3.1 决定类 suid 的因素

上面我们知道 hash 函数是产生这个 suid 的关键,那么要找决定 suid 的因素,首先寻找决定 hash 的因素,也就是 hash 的输入,

MessageDigest md = MessageDigest.getInstance("SHA");
// 输入来源
byte[] hashBytes = md.digest(bout.toByteArray());
long hash = 0;
for (int i = Math.min(hashBytes.length, 8) - 1; i >= 0; i--) {
    hash = (hash << 8) | (hashBytes[i] & 0xFF);
}

我们看到输入来自于 bout 变量,也就是我们要密切关注与下面俩个变量的操作,

// 缓存 hash 输入的对象
ByteArrayOutputStream bout = new ByteArrayOutputStream();
// 方法中往 bout 写入的对象
DataOutputStream dout = new DataOutputStream(bout);

我们只要找到与 bout 有关的动作就可以,上面作者已经将完整的代码注释了,读者可以阅读后看下面总结因素及引起 hash 值变化的具体动作,

因素具体动作
类名修改类名
类修饰符增加、减少和修改类修饰符
类接口增加、减少和实现接口
类成员方法和构造方法增加和减少方法;修改方法签名
类成员变量(包括静态、常量)增加和减少变量;修改变量签名

我们发现除了类所继承的类并不影响 suid 之外,其他类信息的变动都将修改类的 suid

其实这些因素里面还有一些不同的细节,作者已经在源码中标注了,比如对于私有构造方法,并不写入到 hash 输入中。

至于为什么当初的编写者不将类继承算进去,作者还在思考中

3.1 为什么要进行 Sort?

在阅读这个方法时,我们不难发现一些地方经常用到 Array.sort() 方法对各种反射得到的变量进行排序,比如下面的这段代码,

Arrays.sort(fieldSigs, new Comparator<>() {
    public int compare(MemberSignature ms1, MemberSignature ms2) {
        return ms1.name.compareTo(ms2.name);
    }
});

这段代码是对反射得到的类成员变量签名进行排序的方法,其实不难理解,目的就是成员变量之间的位置变换不应该影响一个类的 suid,比如我将 User 的两个成员变量位置互换,反序列化也不会出现异常。

public class User implements Serializable {
    private Integer age;
    public String name;

    ...
}

总结

至此我们的序列化源码分析就结束啦,这是作者花了非常大时间去创作的一篇源码分析,如果你喜欢我这样的新人的话,不妨点个赞。如果你觉得我有些地方可以改进,非常希望在评论区看到你哟!