本章涵盖以下内容:
- 定义并计算 precision、recall,以及 true / false positives / negatives
- 比较 F1 score 与其他质量指标
- 通过数据平衡与数据增强来减少过拟合
- 使用 TensorBoard 绘制质量指标曲线
上一章结束时,我们陷入了一个尴尬局面。虽然深度学习项目的基本机制已经搭起来了,但结果却一点都不实用;网络只是把所有东西都判成了“非结节”!更糟的是,从表面上看结果似乎还很好,因为我们当时盯着的是“训练集和验证集中被正确分类的总体百分比”。可我们的数据极度偏向负样本,于是模型只要盲目地把一切都判成负类,就能很轻松地拿到高分。可惜,这么做会让模型几乎毫无用处!
这意味着,我们此时仍然关注图 14.1 中与第 13 章相同的那一部分。但现在,我们要解决的不再只是“让分类模型能跑起来”,而是“让分类模型真正跑得好”。本章的主题,就是如何去度量、量化、表达,并进一步改进模型完成任务的效果。
图 14.1 我们的端到端肺癌检测项目,本章聚焦于其中的主题:第 3 步,分类
14.1 改进的高层计划
虽然有些抽象,但图 14.2 展示了我们将如何处理这一整组问题。我们先细致地走一遍这张稍显抽象的“章节地图”。这一章中,我们要面对的问题包括:过度聚焦某一个狭窄指标,结果导致模型行为在整体意义上毫无价值。为了让本章的一些概念更具体些,我们会先借助一个比喻,把这些麻烦放到更容易理解的场景里——看门狗(1)、鸟与窃贼(2),如图 14.2 所示。
图 14.2 我们将借助这些比喻,来修改用于衡量模型表现的指标,让模型真正“支棱起来”
在那之后,我们会发展出一套图形化语言,用来表达正式讨论上一章实现问题时所需的一些核心概念——比例:recall 和 precision(3),如图 14.2 所示。等这些概念建立起来之后,我们还会引入一点相关数学,用这些概念把一种更稳健的模型评分方式压缩成一个单一数值——新的指标:F1 score(4)。我们会实现这些新指标的公式,并观察它们在训练过程中如何随 epoch 推进而变化。最后,我们会对 LunaDataset 做一些迫切需要的修改,以改善训练效果——平衡(5)和增强(6)。然后,我们再看看这些实验性修改是否真的对性能指标产生了预期效果。
等这一章结束时,我们训练出来的模型会好得多——真正“workin’ great!”(7),如图 14.2 所示。虽然它还不能直接拿去临床使用,但它已经能够给出明显优于随机猜测的结果。也就是说,我们终于有了整个“结节候选分类”大计划中第 3 步的一个可用实现;等这一步完成后,我们就可以开始考虑如何把第 2 步(segmentation,分割)也纳入到整个项目中来。
14.2 好狗与坏家伙:False positives 与 false negatives
为了开始这个比喻,我们先把“模型”和“肿瘤”暂时放一边,来看看图 14.3 中两只刚从服从训练学校毕业的看门狗。它们都想帮我们发现入室盗贼——这是一种少见但严重、而且必须尽快处理的情况。
图 14.3 本章的一组主题,此处聚焦于引导性比喻
可惜,虽然两只狗都是好孩子,但都算不上好看门狗。我们的梗犬(Chirpy)几乎对什么都叫,而年迈的猎犬(Dozer)几乎只对盗贼叫——但前提是,它在盗贼到来时刚好醒着。
Chirpy 几乎每次都能提醒我们有盗贼。可它也会对消防车、雷暴、直升机、鸟、邮递员、松鼠、路人等等发出警报。如果我们每次听到它叫都去查看,那我们几乎不会被偷(只有最鬼鬼祟祟的小偷才能混过去)。完美!——只是,这么勤快地响应每一次狗叫,其实根本没有替我们节省什么工作。相反,我们会隔几个小时就爬起来,手拿手电筒,因为 Chirpy 闻到了猫、听见了猫头鹰,或者看到一辆深夜公交经过。Chirpy 的问题在于:false positives 太多。
所谓 false positive(假阳性) ,是指某个事件被分类成“值得关注”或“属于目标类别”(这里 positive 的意思是“是的,这正是我想知道的那类东西”),但实际上它并不值得关注。对于结节检测问题来说,这意味着:一个其实无关紧要的候选区域,却被标记成了结节,因此需要放射科医生进一步关注。对于 Chirpy 来说,消防车、雷暴等等都属于这种情况。下一节以及本章后续图中,我们会用一张“猫”的图片作为典型的 false positive。
与 false positive 相对的是 true positive(真阳性) :真正值得关注、并且被正确分类出来的对象。图中,我们会用一个“人类窃贼”的形象来表示 true positive。
另一方面,如果 Dozer 叫了,那就赶紧报警吧,因为这几乎意味着:真的有人闯进来了,或者房子着火了,或者哥斯拉来袭了。问题是,Dozer 睡得太沉了,进行中的入室盗窃那点动静,多半根本吵不醒它,于是只要真的有人来偷,我们几乎总还是会被偷。还是那句话:虽然总比什么都没有强一点,但这并没有真正给我们当初养狗时想要的那种安心感。Dozer 的问题在于:false negatives 太多。
所谓 false negative(假阴性) ,是指某个事件被分类为“不值得关注”或“不属于目标类别”(这里 negative 的意思是“不是,那不是我想知道的那类东西”),但实际上它却恰恰是值得关注的。对于结节检测问题来说,这意味着:一个结节(也就是潜在癌症)没有被检测出来。对于 Dozer 来说,这些就是它睡过去的那些盗窃。这里我们会稍微发挥一点想象力,用一只“啮齿类盗贼”的图片来表示 false negative——因为它们够鬼祟!
与 false negative 相对的是 true negative(真阴性) :那些本来就不值得关注、并且被正确识别为“不值得关注”的对象。这里我们用“鸟”的图片来表示它们。
为了把这个比喻补完整,可以说:第 13 章里的模型,本质上就是一只 Dozer。它把大多数结节都当成“不重要”的东西,因为数据集里的大多数候选本来就不是癌症。这种做法会带来大量 false negatives。上一章末尾,我们关注的是整个训练集和验证集上的“总体正确率”。显然,这并不是一个好的自我评分方式。正如两只狗都只盯着某一个单一指标(比如 true positives 或 true negatives 的数量)一样,我们也需要一个视野更宽的指标,才能捕捉整体表现。
14.3 把 positives 和 negatives 画出来
现在我们开始构建一套视觉语言,用来描述 true / false positives / negatives。若接下来的解释略显重复,还请包涵;我们希望你对接下来要讨论的这些比例,建立起非常牢固的心智模型。请看图 14.4,它展示了某些事件,这些事件对我们的看门狗来说可能“值得关注”。
图 14.4 猫、鸟、啮齿动物和盗贼构成了我们的四个分类象限。它们由“人的标签”和“狗的分类阈值”分隔开来。
在图 14.4 中,我们会用到两条阈值。第一条是人类决定的分界线,它把盗贼与无害动物区分开来。落到具体问题中,这个标签就是我们为每一个训练或验证样本给出的人工标注。第二条则是狗自己决定的分类阈值,它决定了狗是否会对某个对象叫。对于深度学习模型来说,这对应于模型面对一个样本时所输出的预测值。
这两条阈值的组合,会把所有事件划分到四个象限中:true / false positives / negatives。我们会用较深的背景色把那些真正“值得担心”的事件标出来(毕竟坏家伙总是在黑暗里鬼鬼祟祟的)。
当然,现实比这个要复杂得多。并不存在某种“盗贼的柏拉图式理想型”,也不存在一个单一位置,使得所有盗贼都整整齐齐地落在分类阈值的某一侧。相反,图 14.5 展示的是:有些盗贼特别狡猾,有些鸟特别烦人。我们还进一步把这些实例放进了一个图中。X 轴仍然表示每个事件对于狗来说“值不值得叫”,也就是某只看门狗所感知到的 bark-worthiness。Y 轴则表示某种模糊而广泛的人类可感知特征集合,而这些东西狗感知不到。
图 14.5 每一种事件类型都对应着许多可能实例,而我们的看门狗必须逐一判断它们
由于我们的模型做的是二元分类,因此可以把“预测阈值”理解为:把模型输出的一个单一数值,与分类阈值进行比较。这就是为什么图 14.5 中的分类阈值线必须是一条完全竖直的线。
每一个盗贼都不同,所以看门狗必须面对很多不同情形,这也意味着有更多机会犯错。图中我们可以清楚看到那条把“鸟”和“盗贼”分开的斜线,但 Dozer 和 Chirpy 实际上只能感知 X 轴:对它们而言,中间那片区域里各种事件是混杂重叠的。它们必须自己选择一条竖直的 bark-worthiness 阈值线,这意味着它们不可能做到完美。有时候,正把你家家电搬上车的人,是你自己请来的洗衣机维修工;而有时候,盗贼开来的车上也可能写着“洗衣机维修”。指望一条狗去识别这些细微差别,注定会失败。
我们真正要处理的输入数据,维度非常高:我们要考虑大量 CT 体素值,还要考虑更抽象的信息,比如候选大小、在肺部中的整体位置,等等。模型的任务,就是把这些事件及其各自的属性映射进这个矩形空间,并且映射成一种形式:让我们能够用一条竖直的单线(也就是分类阈值)把 positive 和 negative 干净地分开。这个工作,正是由模型末端的 nn.Linear 层完成的。而这条竖直线的位置,就对应着我们在 13.6.1 节中见过的 classificationThreshold。当时,我们把它写死为 0.5。
注意,现实中的数据当然并不是二维的;它是在倒数第二层仍然保持很高维度,然后在输出时压缩到一维(也就是这里的 X 轴)——每个样本只剩一个标量,然后再被分类阈值一刀切开。这里我们引入第二个维度(Y 轴),只是为了表示那些模型看不到、也用不到的样本特征:例如患者年龄、性别、结节候选在肺里的具体位置,甚至是模型尚未利用到的结节局部形态特征。与此同时,这也给了我们一个方便的方式,去可视化非结节样本与结节样本之间的“混淆”。
图 14.5 中四个象限的面积,以及其中包含的样本数,将成为我们讨论模型表现的基础,因为我们可以利用这些数值之间的比例,构造出越来越复杂的指标,用以客观衡量模型到底做得有多好。俗话说,“proof is in the proportions(证据就在比例里)。”(好吧,可能其实没人这么说。)接下来,我们就要开始利用这些事件子集之间的比例,定义出更合适的指标。
14.3.1 Recall 是 Chirpy 的强项
Recall(召回率) 基本上可以理解为:“确保任何值得关注的事件都绝不漏掉!” 更形式化地说,recall 是 true positives 与所有 positives(也就是 true positives 和 false negatives 的并集)之间的比率。图 14.6 就展示了这一点。
图 14.6 Recall 是 true positives 相对于 true positives 与 false negatives 并集的比例。高 recall 意味着尽量减少 false negatives。
注意:在某些语境中,recall 也被称为 sensitivity(灵敏度) 。
要提升 recall,就必须减少 false negatives。用看门狗的话说就是:如果不确定,那就先叫,宁可错叫,也别漏掉。绝不能让那些“老鼠小偷”从你眼皮底下悄悄溜过去!
Chirpy 之所以能把 recall 做得极高,就是因为它把分类阈值一路推到最左边,从而把图 14.7 中几乎所有 positive 事件都纳入“要叫”的范围。注意,这样一来,它的 recall 值就接近 1.0——也就是说,99% 的盗贼都会被它叫出来。既然这就是 Chirpy 眼中“成功”的定义,那么在它自己看来,它干得相当漂亮。至于那大片大片的 false positives?它才不在乎!
图 14.7 Chirpy 所选择的阈值,优先考虑的是尽量减少 false negatives。每一只老鼠都要被叫出来……顺便连猫和大多数鸟也一起叫了。
14.3.2 Precision 是 Dozer 的强项
Precision(精确率) 基本上可以理解为:“除非你很确定,否则绝不要叫。” 要提高 precision,就要尽量减少 false positives。Dozer 除非确定某个东西真的是盗贼,否则不会叫。更正式地说,precision 是 true positives 相对于所有被判成 positive 的对象(也就是 true positives 与 false positives 的并集)之间的比率,如图 14.8 所示。
图 14.8 Precision 是 true positives 相对于 true positives 与 false positives 并集的比例。高 precision 意味着尽量减少 false positives。
Dozer 之所以能拥有极高的 precision,是因为它把分类阈值一路推到了最右边,尽可能把那些无关紧要的负类事件排除在“要叫”的区域之外(见图 14.9)。这与 Chirpy 的做法正好相反,因此 Dozer 的 precision 接近 1.0:它叫的东西里,99% 都真的是盗贼。这也符合它自己心目中“好看门狗”的定义,尽管它因此漏掉了大量真正该发现的事件。
图 14.9 Dozer 所选择的阈值,优先考虑的是尽量减少 false positives。猫都被放过了;只有真正的盗贼才会挨叫!
虽然 precision 和 recall 都不能单独充当评分模型的唯一指标,但它们都是训练过程中非常值得追踪的数值。接下来,我们会把这两个值也计算出来,并显示在训练程序中,然后再讨论还可以使用哪些其他指标。
14.3.3 在 logMetrics 中实现 precision 与 recall
precision 和 recall 都是训练时非常有价值的追踪指标,因为它们能清楚揭示模型到底在怎么行为。如果其中任何一个掉到 0(正如我们在第 13 章中看到的那样!),那大概率意味着模型已经退化成某种病态行为。我们可以根据这种退化的具体形式,去推断应该往哪里调查、往哪里做实验,把训练重新拉回正轨。现在,我们就来更新 logMetrics 函数,把 precision 和 recall 加到每个 epoch 的输出中,让它们和我们已经有的 loss 以及正确率指标形成互补。
事实证明,计算 precision 和 recall 所需的一些值,我们其实已经在算了,只不过之前起了不同的名字而已(training.py:315, LunaTrainingApp.logMetrics)。
代码清单 14.1 计算 true / false positives 和 negatives 的计数
neg_count = int(negLabel_mask.sum())
pos_count = int(posLabel_mask.sum())
trueNeg_count = neg_correct = int((negLabel_mask & negPred_mask).sum())
truePos_count = pos_correct = int((posLabel_mask & posPred_mask).sum())
falseNeg_count = pos_count - pos_correct
falsePos_count = neg_count - neg_correct
这里我们可以看到,neg_correct 其实就是 trueNeg_count!这很合理,因为“非结节”本来就是我们的“负类”(negative,理解为“阴性诊断”),如果分类器把它判对了,那自然就是一个 true negative。同样地,结节样本被正确判出来,就是 true positive。
我们确实还需要额外引入 false positive 和 false negative 的变量。不过这很直接:对于 false positive,只要拿负类样本总数减去负类判对数,剩下的就是“本来是非结节,却被错判成正类”的样本数。既然它们被判成了 positive,但其实不是真的 positive,所以它们是 false positives。false negative 的计算形式完全一样,只不过换成了正类的计数。
有了这些值之后,我们就可以计算 recall 和 precision,并把它们存进 metrics_dict(training.py:333, LunaTrainingApp.logMetrics)。
代码清单 14.2 根据分类计数计算 recall 与 precision
recall = metrics_dict['pr/recall'] = \
truePos_count / np.float32(truePos_count + falseNeg_count)
precision = metrics_dict['pr/precision'] = \
truePos_count / np.float32(truePos_count + falsePos_count)
14.3.4 我们的终极性能指标:F1 score
precision 和 recall 虽然都很有用,但它们谁都不能完整地捕捉我们真正需要的那种模型评价方式。正如我们从 Chirpy 和 Dozer 身上看到的那样,只要把分类阈值往某个方向推,你就可以“刷”出一个单项看起来很好的 precision 或 recall,可现实中的实际效用却可能糟糕透顶。我们需要一种能够把这两个值结合起来的方式,并且这种方式不容易被这样的“投机取巧”钻空子。正如图 14.10 所示,现在是引入我们终极指标的时候了。
图 14.10 本章的一组主题,此处聚焦于最终的 F1 score 指标
把 precision 和 recall 结合起来的通行做法,是使用 F1 score(en.wikipedia.org/wiki/F-scor…)。和其他指标一样,F1 的取值范围也在 0 到 1 之间:0 表示分类器几乎没有现实中的预测能力,1 表示预测完美无缺。我们也会更新 logMetrics,把它加进去。
代码清单 14.3 计算 F1 score
metrics_dict['pr/f1_score'] = \
2 * (precision * recall) / (precision + recall)
乍一看,F1 score 的公式似乎比想象中复杂,而且并不直观地看出它到底是如何平衡 precision 和 recall 的。不过,这个公式有许多很好的性质,而且通常要比一些更简单的替代方案表现得更合理。
一个最直接的想法,是干脆把 precision 和 recall 做平均。可惜,这样的话,avg(p=1.0, r=0.0) 和 avg(p=0.5, r=0.5) 都会得到同样的分数 0.5。而正如我们前面讨论过的,只要 precision 或 recall 其中之一为 0,这种分类器通常就已经没什么用了。让一个毫无用处的模型,和一个至少还有点用的模型拿到同样的非零分数,这就已经足够说明“简单平均”根本不适合作为有意义的指标。
尽管如此,我们还是可以视觉上把“平均值”与 F1 对比一下,如图 14.11 所示。你会立刻注意到几件事。首先,平均值的等高线没有明显的弯折或“肘部”。正因为如此,precision 或 recall 可以很轻易地向一边倾斜!总会存在一种策略,让你通过把 recall 拉到 100%(也就是 Chirpy 的做法)去最大化得分。这样一来,单凭这个加法式指标,一上来就有了 0.5 的下限!一个质量指标,如果轻轻松松就能拿到至少 50 分,总感觉哪里不对。
图 14.11 使用 avg(p, r) 来计算最终分数。颜色越浅越接近 1.0,越深越接近 0.0
从数学上说,图 14.11 左边实际上是在对 precision 和 recall 取 算术平均(en.wikipedia.org/wiki/Arithm…),而它们本身是“率”而不是“可数标量”。对“率”取算术平均,通常并不会给出特别有意义的结果。F1 score 则是这两个率的 调和平均(en.wikipedia.org/wiki/Harmon…)的另一种写法,而这种方式更适合用来组合这类数值。
再来看 F1:当 recall 很高但 precision 很低时,只要牺牲一点 recall 去换取哪怕少量 precision,分数都会朝那个“平衡甜点区”移动。它有一个很明显、很深的“肘部”,模型很容易被引导滑向那个区域。这种鼓励 precision 和 recall 取得平衡的特性,正是我们希望一个评分指标具备的。
假设我们还是想要一个更简单的指标,但又不希望它奖励任何倾斜,那么我们可能会想:干脆取 precision 和 recall 的最小值好了(见图 14.12)。
图 14.12 使用 min(p, r) 来计算最终分数
这个做法有个优点:如果任意一个值为 0,那么分数也是 0;而想拿到 1.0,则必须两个值都为 1.0。不过,它依然不够理想。比如说,如果你做了一次模型改动,使 recall 从 0.7 提升到 0.9,而 precision 保持 0.5 不变,那么最终分数不会增加;反过来,即使 recall 掉到 0.6,也同样不会让分数变化!也就是说,虽然这个指标确实会惩罚 precision 和 recall 失衡,但它没有充分表达两者之间那些更细腻的变化关系。正如我们已经看到的,仅靠挪动分类阈值,就很容易在二者之间做权衡;我们希望自己的指标能真实反映出这种权衡。
这就意味着,为了更好地达到目标,我们得接受一点额外的复杂性。比如,我们也可以把这两个值直接相乘,如图 14.13 所示。这种方法同样具备前述好性质:任意一个值为 0,则得分为 0;得分为 1.0 说明两个输入都完美。而且,在较低数值区域,它确实也偏向于鼓励 precision 和 recall 保持均衡;只是当它接近完美区域时,走势会越来越线性。这并不理想,因为在那个阶段,我们其实希望两个值都同时大幅提高,才算得上真正 meaningful 的改进。
图 14.13 使用 mult(p, r) 来计算最终分数
注意:这里实际上是在对两个率取 几何平均(en.wikipedia.org/wiki/Geomet…),这同样并不会产生真正有意义的结果。
此外,还有个问题:从 (0, 0) 到 (0.5, 0.5) 的几乎整个象限区域,分值都很接近 0。正如我们会看到的,在模型设计早期阶段,让指标对这一区域内的变化足够敏感,是很重要的。
尽管“相乘”作为评分函数并不是不可行(不像前面那几个方案那样一眼就能被判出局),但接下来我们还是会使用 F1 score 作为后续评估分类模型性能的主指标。
更新日志输出,把 precision、recall 和 F1 都打出来
现在我们已经有了这些新指标,把它们加进日志输出就很简单了。我们会在训练集和验证集每个 epoch 的主日志语句里,加入 precision、recall 和 F1(training.py:341, LunaTrainingApp.logMetrics)。
代码清单 14.4 在 epoch 日志中输出 precision、recall 和 F1
log.info(
f"E{epoch_ndx} {mode_str:8} {metrics_dict['loss/all']:.4f} loss, "
f"{metrics_dict['correct/all']:-5.1f}% correct, "
f"{metrics_dict['pr/precision']:.4f} precision, " #1
f"{metrics_dict['pr/recall']:.4f} recall, " #1
f"{metrics_dict['pr/f1_score']:.4f} f1 score" #1
)
#1 更新后的 format string
除此之外,我们还会在负类与正类的日志中,加入“正确识别数 / 总样本数”的精确计数(training.py:353, LunaTrainingApp.logMetrics)。
代码清单 14.5 记录负样本的分类正确情况
log.info(
f"E{epoch_ndx} {mode_str + '_neg':8} {
↪metrics_dict['loss/neg']:.4f} loss, "
f"{metrics_dict['correct/neg']:-5.1f}% correct (
↪{neg_correct} of {neg_count})"
)
新的正类日志语句与此几乎完全相同。
14.3.5 用这些新指标来看,我们的模型表现如何?
既然这些闪亮的新指标都已经实现好了,那就来真正跑一跑吧。下面我们先展示一段 Bash shell 的输出,然后再讨论这些结果。趁系统在 crunch numbers 的时候,你也可以先继续往下读;具体训练时间视硬件情况而定,可能要半小时左右。(如果耗时远远超过这个量级,请确认你已经先运行过 prepcache 脚本。)到底多久跑完,会取决于你的 CPU、GPU 和磁盘速度;我们这边一台带 SSD 和 GTX 1080 Ti 的机器,大约每个完整 epoch 需要 20 分钟:
$ ../.venv/bin/python -m p2ch14.training
Starting LunaTrainingApp...
...
E1 LunaTrainingApp
.../p2ch14/training.py:274: RuntimeWarning:
↪ invalid value encountered in double_scalars
metrics_dict['pr/f1_score'] = 2 * (precision * recall) /
↪ (precision + recall) #1
E1 trn 0.0025 loss, 99.8% correct, 0.0000 prc, 0.0000 rcl, nan f1
E1 trn_ben 0.0000 loss, 100.0% correct (494735 of 494743)
E1 trn_mal 1.0000 loss, 0.0% correct (0 of 1215)
.../p2ch14/training.py:269: RuntimeWarning:
↪ invalid value encountered in long_scalars
precision = metrics_dict['pr/precision'] = truePos_count /
↪ (truePos_count + falsePos_count)
E1 val 0.0025 loss, 99.8% correct, nan prc, 0.0000 rcl, nan f1
E1 val_ben 0.0000 loss, 100.0% correct (54971 of 54971)
E1 val_mal 1.0000 loss, 0.0% correct (0 of 136)
#1 这些 RuntimeWarning 的具体计数和行号,在不同运行中可能会略有差异。
真糟糕。我们收到了若干 warning,而且既然有些值算成了 nan,那大概率说明某个地方发生了“除以零”。来看看能推断出什么。
首先,由于训练集中一个正样本都没有被判成正类,所以 precision 和 recall 都是 0,于是 F1 score 的公式就变成了除以 0。其次,在验证集上,truePos_count 和 falsePos_count 都是 0,因为模型压根没有把任何东西标成正类。于是 precision 的分母也变成了 0;这也正好解释了我们看到的另一个 RuntimeWarning。
另外,有极少数负类训练样本被错判成了正类(494,735 个被判负,494,743 个总负样本,也就是有 8 个样本被误判)。这乍一看有点奇怪,但请记得:我们记录训练结果,是在整个 epoch 过程中不断收集的,而不是像验证那样只看 epoch 末尾时刻的模型状态。这意味着第一批数据的输出,实际上就是纯随机的。既然如此,第一批里有少数样本被随机判成正类,也就不足为奇了。
注意:由于网络权重是随机初始化的,训练样本每个 epoch 的选取与顺序也都是随机的,因此不同运行之间很可能会表现出略有不同的行为。完全可复现的行为当然有时很 desirable,但这超出了本书第二部分当前想传达的核心概念范围。
总之,这一幕有点惨。换上这些新指标之后,我们的成绩从 A+ 掉成了“运气好的话还能拿个零分”;运气不好时,糟糕到连分数都不是个数。疼。话虽如此,从长远看这其实是好事。因为从第 13 章开始,我们就已经知道模型表现垃圾。如果这些指标告诉我们的不是“垃圾”,那反倒说明指标本身有根本性缺陷!
14.4 理想的数据集应该长什么样?
在为当前局面哀叹之前,不如先想一想:我们到底希望模型做成什么样。图 14.14 告诉我们:第一步需要先让数据更平衡,这样模型才能真正开始正常训练。我们先一步步铺垫出通往那里的逻辑路径。
图 14.14 本章的一组主题,此处聚焦于如何让正负样本更平衡
回想图 14.5,以及前面对分类阈值的讨论。单纯通过移动阈值来改善结果,其效果是有限的——因为正类和负类之间的重叠实在太严重了。(记住,这些图只是分类空间的示意图,并不代表真实 ground truth。)
相反,我们真正希望看到的是图 14.15 那样的画面:这里,标签阈值几乎是一条竖线。之所以理想,是因为那意味着标签阈值和分类阈值可以比较好地对齐。与此同时,大多数样本都聚集在图的两端。想做到这一点,一方面要求数据本身足够可分,另一方面要求模型具备执行这种分离的能力。我们的模型当前其实已经有足够的 capacity,所以问题不在这里。现在,先来看看数据。
图 14.15 一个训练良好的模型,能够把数据清晰分开,从而可以用很少的代价就选出合适的分类阈值
还记得我们的数据极其不平衡吗?负样本与正样本的比例大约是 400:1。这已经不是“不平衡”,而是压倒性失衡了!图 14.16 展示了这种情况大概长什么样。难怪那些“真的结节”样本会被淹没在人群里!
图 14.16 一个高度不平衡的数据集,大致近似了我们 LUNA 分类数据中的失衡程度
现在要把话说清楚:最终,我们的模型是能够应对这种不平衡数据的。假如我们愿意等上“一个天文数字般多”的 epoch,不调整平衡设置,理论上也许仍然能把模型一路训练到位。(这并不一定真的成立,但至少看起来是 plausible 的,毕竟 loss 确实一直在变好……)可惜,我们都是忙人,事情很多,与其把 GPU 烤到宇宙热寂,不如想办法把训练数据调整得更接近“理想状态”,也就是改变训练时所使用的类别平衡。
14.4.1 让数据看起来少一点“现实世界”,多一点“理想世界”
最好的情况,当然是让正样本相对更多一些。尤其是在训练初期、模型还在从随机混沌走向稍微有点结构的时候,如果正样本太少,它们就很容易被淹没掉。
这种现象的发生机制其实比较微妙。记住,由于网络权重一开始是随机的,所以模型对每个样本的输出也是随机的(虽然最终会被压缩到 [0,1] 这个范围内)。
注意:我们的 loss 函数其实是 nn.CrossEntropyLoss,严格来说它直接作用于原始 logits,而不是类别概率。为了便于讨论,我们暂时忽略这一点,把 loss 以及“标签与预测之间的差距”粗略视作一回事。
那些数值上本来就比较接近正确标签的预测,不会对网络权重造成太大变化;而那些与正确答案相差很大的预测,则会推动权重大幅调整。既然模型初始是随机权重,我们就可以粗略认为,在大约 50 万个训练样本(准确说是 495,958 个)里,会出现以下四组:
- 250,000 个负样本会被预测为负类(0.0–0.5),它们对网络权重朝“更倾向负类”方向的推动非常小
- 250,000 个负样本会被预测为正类(0.5–1.0),它们会强烈推动网络权重朝“更倾向负类”方向移动
- 500 个正样本会被预测为负类,它们会推动网络权重朝“更倾向正类”方向移动
- 500 个正样本会被预测为正类,对权重几乎不会造成什么变化
注意:别忘了,实际预测值是位于 0.0 到 1.0 之间的连续实数,因此这些组并不存在非常刚性的边界。
真正关键的点在于:第 1 组和第 4 组对训练几乎没有什么影响。真正影响训练方向的,是第 2 组和第 3 组之间是否能彼此抵消,足以避免网络塌缩成那种“永远只输出同一种答案”的病态状态。问题在于,第 2 组比第 3 组大了 500 倍;而我们用的 batch size 是 32,也就是说大概每 500/32 ≈ 15 个 batch 才能看到 1 个正样本。这意味着 15 个训练 batch 里,大约有 14 个 batch 会是 100% 负样本,它们只会一边倒地把模型权重全部拉向“预测负类”。正是这种极端失衡的拉扯,造成了我们一直看到的退化行为。
我们真正想要的,是训练时正负样本数量相当。这样一来,在训练初期,两个类别中大约各有一半会被判错,于是前面说的第 2 组和第 3 组在规模上就会接近得多。我们还希望每个 batch 中都混合包含负样本与正样本。平衡的数据会让这场“拔河”趋于均衡,而每个 batch 中类别的混合,也会让模型有机会真正去学习区分这两个类别。由于 LUNA 数据里真正的正样本数量就这么一点而且是固定的,我们只能接受这样一个现实:把现有这些正样本在训练中反复重复呈现。
Discrimination(判别能力)
这里我们把 discrimination 定义为:“把两个类别区分开的能力。” 训练一个模型,让它能够把“真正的结节候选”与正常解剖结构区分开,这正是本书第二部分一直在做的事情。
不过,其他语境下的 discrimination 这个词也可能涉及更麻烦的问题。虽然这超出了我们当前项目讨论范围,但现实世界数据训练出来的模型,确实面临一个更大的难题:如果数据本身采集自带有现实偏见的来源(例如逮捕和定罪率中的种族偏见,或者社交媒体上的各种偏见),而在数据准备和训练过程中又没有加以纠正,那么最终训练出的模型就会持续表现出训练数据中已有的那种偏见。和人一样,种族主义也是学来的。
因此,几乎任何基于“互联网大范围数据源”训练出来的模型,都会在某种程度上被污染,除非你极其谨慎地清洗这些偏见。值得注意的是,这和我们在本书第二部分中面对的问题一样,目前也仍然是未解决问题。
还记得第 13 章中那位教授吗?期末考试里 99 道题答案是 False,1 道题答案是 True。下个学期,有人告诉教授:“你应该把 True 和 False 的比例弄得更均衡一些。” 于是教授决定:加一场期中考试,里面 99 道题答案是 True,1 道题答案是 False。——“问题解决了!”
显然,真正正确的做法,是把 True 和 False 以一种混合方式交错排列,让学生无法利用整张试卷的大结构去“投机取巧”地答题。一个学生也许会看出“奇数题都是真,偶数题都是假”这种模式,但 PyTorch 的 batching 机制不会让模型有机会察觉或利用这种“题号模式”。因此,我们的训练数据集需要被调整成正负样本交替出现的方式,如图 14.17 所示。
图 14.17 如果数据极端不平衡,那么一整个 batch 接一整个 batch 都可能只有负类,远远早于第一个正类样本出现;而平衡数据则可以做到大致隔一个样本就交替一次
不过,我们不会对验证集做任何平衡处理。因为模型最终必须在真实世界中正常工作,而真实世界就是不平衡的(毕竟我们的原始数据就是从那里来的)。那么,该如何实现训练集的平衡呢?来讨论一下可选方案。
Sampler 可以重塑数据集
DataLoader 的一个可选参数是 sampler=...。这个参数允许 DataLoader 覆盖传入数据集本来的迭代顺序,转而按照我们想要的方式去重排、限制或者重新强调底层数据。当你面对一个你无法控制的数据集时,这会非常有用。直接拿公开数据集做一些重塑,使之适配自己的需求,远比从头重写那个数据集类要省事得多。
但问题在于,许多 sampler 能实现的那些“变形”,要求我们打破底层数据集的封装。比如,假设我们有一个类似 CIFAR-10(www.cs.toronto.edu/~kriz/cifar…)的数据集,它原本有 10 个均衡类别,而我们现在想让其中某一个类别(比如“airplane”)占到所有训练图像的 50%。我们当然可以使用 WeightedRandomSampler(mng.bz/vZ5m),把属于“airplane”的样本索引赋予更高权重。但问题在于:要构造这个 weights 参数,我们首先必须提前知道“哪些索引是飞机”。
而正如前面讨论过的,Dataset API 只要求子类实现 __len__ 和 __getitem__,却没有任何直接的方法让我们问一句:“哪些样本是飞机?” 于是,要么我们必须预先把所有样本都加载一遍,逐个询问它们属于哪个类别;要么我们只能打破封装,直接去窥视某个 Dataset 子类的内部实现,碰运气看看所需信息是不是容易读出来。
在我们自己能够直接控制数据集的情况下,这两种做法都不算理想。因此,本书第二部分选择把所有必要的数据塑形逻辑直接写进 Dataset 子类内部,而不是依赖外部 sampler。
在 Dataset 中实现类别平衡
我们将直接修改 LunaDataset,让它在训练时呈现一个 positive : negative = 1 : 1 的平衡比例。我们会分别维护两份训练样本列表:一份负样本,一份正样本;然后交替地从这两份列表中返回样本。这样一来,模型就无法再靠“对所有样本都回答 false”来刷高分。与此同时,正负类别在 batch 中交错混合,也会迫使权重更新朝着真正区分类别的方向前进。
我们给 LunaDataset 增加一个 ratio_int,它将决定第 N 个样本应当属于哪一类,同时也会帮助我们按标签分别保存样本(dsets.py:217, class LunaDataset)。
代码清单 14.6 在 LunaDataset 中引入基于比例的标签平衡
class LunaDataset(Dataset):
def __init__(self,
val_stride=0,
isValSet_bool=None,
ratio_int=0,
):
self.ratio_int = ratio_int
# ... line 228
self.negative_list = [
nt for nt in self.candidateInfo_list if not nt.isNodule_bool
]
self.pos_list = [
nt for nt in self.candidateInfo_list if nt.isNodule_bool
]
# ... line 265
def shuffleSamples(self): #1
if self.ratio_int:
random.shuffle(self.negative_list)
random.shuffle(self.pos_list)
#1 我们会在每个 epoch 的开始调用它,以随机化样本呈现顺序。
这样一来,我们就有了专门针对每个标签的独立列表。有了这些列表之后,就能更容易地控制“给定某个数据集索引时,应该返回什么标签的样本”。为了确认索引逻辑没写错,我们可以先手动推演一下想要的排序。假设 ratio_int = 2,也就是负样本 : 正样本 = 2:1,那就意味着每三个位置里,应该有一个正样本:
DS Index 0 1 2 3 4 5 6 7 8 9 ...
Label + - - + - - + - - +
Pos Index 0 1 2 3
Neg Index 0 1 2 3 4 5
数据集索引与正样本索引之间的关系很简单:把数据集索引除以 3,再向下取整即可。负样本索引稍微复杂一点,因为你需要先从数据集索引里减去 1,再减去最近一次出现的正样本索引。把这个逻辑实现进 LunaDataset,大致如下(dsets.py:286, LunaDataset.__getitem__)。
代码清单 14.7 返回平衡后的正负样本
def __getitem__(self, ndx):
if self.ratio_int: #1
pos_ndx = ndx // (self.ratio_int + 1)
if ndx % (self.ratio_int + 1): #2
neg_ndx = ndx - 1 - pos_ndx
neg_ndx %= len(self.negative_list) #3
candidateInfo_tup = self.negative_list[neg_ndx]
else:
pos_ndx %= len(self.pos_list)
candidateInfo_tup = self.pos_list[pos_ndx]
#1 ratio_int = 0 表示使用原始数据本来的类别比例。
#2 余数非 0,说明当前位置应当是负样本。
#3 一旦索引越界,就从头 wraparound。
这段代码稍微有点绕,但如果你手工桌面推演一下,就会明白。需要注意的是:当 ratio 比较小的时候,我们会在跑完整个数据集之前,就先把正样本用完。我们通过在索引 self.pos_list 之前先对 pos_ndx 取模,来处理这个问题。虽然由于负样本很多,neg_ndx 理论上几乎不可能溢出,但我们还是也对它做了取模,以防以后改动逻辑时导致它也出现越界。
我们还会修改数据集的长度。虽然这一步不是绝对必要的,但它能让每个 epoch 跑得更快一些。这里我们直接把 __len__ 硬编码成 200,000(dsets.py:280, LunaDataset.__len__)。
代码清单 14.8 数据集长度的调整
def __len__(self):
if self.ratio_int:
return 200000
else:
return len(self.candidateInfo_list)
现在我们已经不再被原始样本总数绑死了,而且在平衡训练集的前提下,完整跑一个“全样本 epoch”其实也没什么意义,因为那意味着正样本必须被反复重复很多很多次。把长度定为 200,000,有几个好处:可以缩短“启动一次训练到看到结果”之间的等待时间(更快的反馈永远是好事),同时还能让每个 epoch 对应一个干净、整齐的样本数。你完全可以根据自己的需要,把 epoch 的长度调成别的数。
为了完整性,我们还会给命令行接口加一个参数(training.py:31, class LunaTrainingApp)。
代码清单 14.9 增加 --balanced 命令行参数
class LunaTrainingApp:
def __init__(self, sys_argv=None):
# ... line 52
parser.add_argument('--balanced',
help="Balance the training data to half positive, half negative.",
action='store_true',
default=False,
)
然后,再把这个参数传给 LunaDataset 构造函数(training.py:137, LunaTrainingApp.initTrainDl)。
代码清单 14.10 向 LunaDataset 传入新的平衡参数
def initTrainDl(self):
train_ds = LunaDataset(
val_stride=10,
isValSet_bool=False,
ratio_int=int(self.cli_args.balanced), #1
)
#1 这里利用了 Python 中 True 可以自动转成整数 1。
现在一切都准备好了。来跑一下看看!
14.4.2 将平衡版 LunaDataset 的训练结果,与之前的 run 进行对比
回忆一下,之前“不做平衡”的训练输出大概是这样的:
$ python -m p2ch14.training
...
E1 LunaTrainingApp
E1 trn 0.0185 loss, 99.7% correct, 0.0000 precision, 0.0000 recall,
↪ nan f1 score
E1 trn_neg 0.0026 loss, 100.0% correct (494717 of 494743)
E1 trn_pos 6.5267 loss, 0.0% correct (0 of 1215)
...
E1 val 0.0173 loss, 99.8% correct, nan precision, 0.0000 recall,
↪ nan f1 score
E1 val_neg 0.0026 loss, 100.0% correct (54971 of 54971)
E1 val_pos 5.9577 loss, 0.0% correct (0 of 136)
但如果加上 --balanced,我们会看到:
$ python -m p2ch14.training --balanced
...
E1 LunaTrainingApp
E1 trn 0.1734 loss, 92.8% correct, 0.9363 precision, 0.9194 recall,
↪ 0.9277 f1 score
E1 trn_neg 0.1770 loss, 93.7% correct (93741 of 100000)
E1 trn_pos 0.1698 loss, 91.9% correct (91939 of 100000)
...
E1 val 0.0564 loss, 98.4% correct, 0.1102 precision, 0.7941 recall,
↪ 0.1935 f1 score
E1 val_neg 0.0542 loss, 98.4% correct (54099 of 54971)
E1 val_pos 0.9549 loss, 79.4% correct (108 of 136)
这看起来好多了!我们在负样本上的正确率只损失了大约 1.6%(val_neg: 100% → 98.4%),却换来了 79% 的正样本正确识别率(val_pos: 0% → 79.4%)。我们又回到了一个相当体面的 “B” 档水平!别忘了,这还是在只呈现了 200,000 个训练样本之后得到的结果,而不是像不平衡版本那样跑完 50 多万样本。也就是说,我们在不到一半时间里,就得到了明显更好的结果。
不过,正如第 13 章那样,这个结果依然有迷惑性,因为类别失衡实在太严重了。我们来算一笔账。验证集中,大约每 400 个负样本才对应 1 个正样本。即使只把 1% 的负样本误判成正类,也会造成大问题:
- 验证集中负样本总数:54,971
- 验证集中正样本总数:136
- 负样本若以 1% 错误率产生的 false positives:54,971 × 0.01 ≈ 550
- false positives 与真实正样本总数的比例:550 ÷ 136 ≈ 4:1
这意味着:即使在负样本上达到了 99% 的准确率,我们仍然会错误标记出大约 4 倍于整个验证集中真实正样本数量 的假阳性!这会严重削弱模型的实际实用价值,即便它的总体正确率看起来很高。
尽管如此,这已经明显比第 13 章那种“完全错误的行为”强太多了,也远远好过抛一枚随机硬币。事实上,我们甚至已经几乎跨进了“在现实场景里开始有点实际用处”的门槛。还记得那位过劳的放射科医生,需要逐一检查 CT 中每一个细小斑点吗?现在,我们已经有了一个系统,能够 reasonably 地把 98.4% 的负样本筛掉。这是一种巨大帮助,因为它几乎意味着机器辅助下的人类生产率提升了一个数量级。
当然,还有那个恼人的问题:20.6% 的正样本仍然被漏掉了(因为我们只识别出了 79.4%)。或许,再多训练几个 epoch 会有帮助。来看看(而且再次提醒,每个 epoch 依然可能要跑至少 10 分钟):
$ python -m p2ch14.training --balanced --epochs 20
...
E2 LunaTrainingApp
E2 trn 0.0432 loss, 98.7% correct, 0.9866 precision, 0.9879 recall,
↪ 0.9873 f1 score
E2 trn_ben 0.0545 loss, 98.7% correct (98663 of 100000)
E2 trn_mal 0.0318 loss, 98.8% correct (98790 of 100000)
E2 val 0.0603 loss, 98.5% correct, 0.1271 precision, 0.8456 recall,
↪ 0.2209 f1 score
E2 val_ben 0.0584 loss, 98.6% correct (54181 of 54971)
E2 val_mal 0.8471 loss, 84.6% correct (115 of 136)
...
E5 trn 0.0578 loss, 98.3% correct, 0.9839 precision, 0.9823 recall,
↪ 0.9831 f1 score
E5 trn_ben 0.0665 loss, 98.4% correct (98388 of 100000)
E5 trn_mal 0.0490 loss, 98.2% correct (98227 of 100000)
E5 val 0.0361 loss, 99.2% correct, 0.2129 precision, 0.8235 recall,
↪ 0.3384 f1 score
E5 val_ben 0.0336 loss, 99.2% correct (54557 of 54971)
E5 val_mal 1.0515 loss, 82.4% correct (112 of 136)...
...
E10 trn 0.0212 loss, 99.5% correct, 0.9942 precision, 0.9953 recall,
↪ 0.9948 f1 score
E10 trn_ben 0.0281 loss, 99.4% correct (99421 of 100000)
E10 trn_mal 0.0142 loss, 99.5% correct (99530 of 100000)
E10 val 0.0457 loss, 99.3% correct, 0.2171 precision, 0.7647 recall,
↪ 0.3382 f1 score
E10 val_ben 0.0407 loss, 99.3% correct (54596 of 54971)
E10 val_mal 2.0594 loss, 76.5% correct (104 of 136)
...
E20 trn 0.0132 loss, 99.7% correct, 0.9964 precision, 0.9974 recall,
↪ 0.9969 f1 score
E20 trn_ben 0.0186 loss, 99.6% correct (99642 of 100000)
E20 trn_mal 0.0079 loss, 99.7% correct (99736 of 100000)
E20 val 0.0200 loss, 99.7% correct, 0.4780 precision, 0.7206 recall,
↪ 0.5748 f1 score
E20 val_ben 0.0133 loss, 99.8% correct (54864 of 54971)
E20 val_mal 2.7101 loss, 72.1% correct (98 of 136)
唉,这一大段输出要滚半天才能看到真正想看的数字。我们硬着头皮抓住重点:看 val_mal XX.X% correct 这一列就够了。epoch 2 时是 84.6% ;到 epoch 5 掉到 82.4% ;到了 epoch 20,则进一步掉到 72.1% ——这已经是相当明显的下降趋势了!
注意:如前面所说,由于网络权重的随机初始化,以及每个 epoch 训练样本的随机抽取与排序,不同 run 的具体表现会有所不同。
而训练集上的数字看起来并没有同样的问题。训练到 20 个 epoch 时,负样本正确率是 99.6%,正样本正确率是 99.7%。为什么训练集和验证集差别会这么大?
14.4.3 识别过拟合的症状
我们现在看到的,就是非常典型的 overfitting(过拟合) 。来看图 14.18 中正样本 loss 的曲线。
图 14.18 我们的正样本 loss 显示出明显的过拟合迹象,因为训练 loss 和验证 loss 正在朝相反方向变化
从图中你可以看到:正样本上的训练 loss 已经几乎为 0,也就是说,每一个正类训练样本都被预测得几乎完美。然而,正样本上的验证 loss 却在不断升高,这意味着模型在真实世界中的表现很可能正在变差。到这个时候,通常最好的做法就是把训练停下来,因为模型已经不再继续改善了。
提示:一般来说,如果模型在训练集上表现越来越好,而在验证集上却越来越差,那就说明模型开始过拟合了。
这里非常重要的一点是:我们必须盯住正确的指标,因为这个问题只发生在正样本 loss 上。如果你看的是“总体 loss”,一切看起来都还不错!这是因为验证集本来就没有做平衡,所以总体 loss 完全被负样本主导了。如图 14.19 所示,负样本并没有显示出同样的发散行为。相反,负样本 loss 看起来还非常漂亮!这是因为负样本数量比正样本多 400 倍,因此模型几乎不可能“记住”每一个负样本的细节。可正样本训练集只有 1,215 个样本。虽然这些正样本在训练中被反复重复了很多次,但这并不会让它们变得更难被记住。结果是:模型正在从学习一般规律,逐步滑向“记忆这 1,215 个正样本的怪癖细节”,然后对所有“不长得像这 1,215 个已记住样本”的东西,一律回答负类。这包括了负类训练样本,也包括验证集中的所有样本(不管它们是正是负)。
图 14.19 我们的负样本 loss 没有表现出任何过拟合迹象
当然,某种程度上的泛化依然存在,因为模型仍然能把大约 70% 的正类验证样本判对。我们需要做的,只是改变模型的训练方式,让训练集和验证集都朝正确方向一起变化。
14.5 重新审视过拟合问题
我们在第 6 章里已经提到过过拟合的概念,而现在到了更仔细讨论“如何处理这个常见问题”的时候。训练模型的目标,是让它学会识别我们感兴趣类别在数据中表现出来的那些一般性特征。这些一般性特征存在于该类别的部分样本或多数样本中,具有可泛化性,因此可以用来预测那些没见过的新样本。当模型开始学习训练集中特定样本的特殊性时,就发生了过拟合;此时模型会逐渐失去泛化能力。如果这段话有点抽象,那就再用一个类比。
14.5.1 一个过拟合的“人脸预测年龄”模型
假设我们有一个模型,输入是一张人脸照片,输出是这张脸对应的年龄(单位:岁)。一个好的模型,应当学会识别年龄特征,例如皱纹、白发、发型、穿衣选择等等,并据此建立对“不同年龄的人通常长什么样”的一般化模型。于是,当它看到一张新照片时,会综合考虑“发型比较保守”“戴着老花镜”“有皱纹”等特征,从而得出“大概 65 岁左右”的判断。
而一个过拟合模型则不同。它不会学年龄特征,而是记住具体的人,通过个体身份特征来记忆。“这个发型加这副眼镜,说明是 Frank。他 62.8 岁。”“哦,这道疤是 Harry。他 39.3 岁。” 之类的。于是,当它看到一个陌生人时,它根本认不出来,也就完全不知道该预测几岁。
更糟糕的是,如果你给它看 Frank Jr. 的照片(尤其是 Junior 戴上眼镜之后,跟他老爸长得简直一模一样!),模型就会说:“我觉得这是 Frank。他 62.8 岁。” ——完全不管 Junior 其实比老爸小 25 岁!
过拟合通常是因为:相对于模型“把答案直接背下来”的能力,训练样本数量太少了。普通人可以轻松记住自己直系亲属的生日,但如果让他们去猜一个比小村庄还大的人群的年龄,就不得不依赖一般化规律了。
我们的“人脸预测年龄”模型,恰恰就有能力直接记住那些“不太符合常见年龄特征”的人的照片。正如我们在本书第一部分讨论过的,模型 capacity(容量)是一个稍微有点抽象的概念,但大致上可以理解为:模型参数数量 × 这些参数被高效利用的程度。当模型的容量,相对于“记住训练集中那些难样本所需的数据量”而言太高时,模型就很容易开始在这些难样本上过拟合。
14.6 通过数据增强防止过拟合
现在,是时候把模型训练从“不错”推进到“真正好”了。我们还需要完成图 14.20 里的最后一步。
图 14.20 本章的一组主题,此处聚焦于数据增强
所谓对数据集做 augmentation(增强),就是对单个样本施加一些合成变换,从而得到一个“有效规模比原始数据集更大”的新数据集。通常,增强的目标是:变换后的样本依然应当能代表原始样本所属的同一个一般类别,但又不能被模型轻易地和原样本一起直接记住。如果增强做得好,它就能把训练集的有效大小扩展到超过模型的直接记忆能力,迫使模型越来越依赖泛化,这正是我们想要的。尤其是在数据有限时,这种做法特别有价值。
当然,并不是所有增强方式都同样有用。继续拿前面那个“人脸预测年龄”模型举例:我们当然可以很容易地把每张图像四个角上像素的红色通道随机改成一个 0–255 的值,这样一来数据集的有效大小一下子能扩大到原来的 40 亿倍。但这显然没什么意义,因为模型完全可以轻松学会无视这些角落里的红点,而图像其余部分依然和原始样本一样容易被记住。相比之下,把图像左右翻转就有意义得多。这样最多只会让数据集翻倍,但每张图像会变得更适合用于训练。衰老的一般特征本来就和左右方向无关,所以镜像图像仍然具有代表性;与此同时,人脸通常也不可能做到完全左右对称,因此镜像版本也不太可能和原图一起被轻易记住。
14.6.1 具体的数据增强技术
我们将实现 5 种具体的数据增强方式。实现会允许我们对它们做单独实验,也允许任意组合使用。具体包括:
- 沿上下、左右、前后方向翻转图像
- 让图像在空间上偏移几个体素
- 对图像做轻微缩放
- 围绕头脚轴旋转图像
- 向图像中添加噪声
注意:具体该使用哪些图像增强方式,必须依赖领域知识。比如,对人脸图像做左右翻转通常还合理,但如果上下翻转,那就会变成一个完全不自然、不真实的训练样本。
对于每一种增强,我们都必须确保:变换后的样本仍然保持原有代表性,同时又要和原样本足够不同,值得拿来训练。
我们会定义一个 getCtAugmentedCandidate 函数,专门负责获取标准 CT candidate chunk,并对它进行修改。总体思路是:先定义一个仿射变换矩阵(mng.bz/4n25),然后配合 PyTorch 的 affine_grid(mng.bz/QwN1)和grid_sample(mng.bz/X76l)函数,对 candidate 进行重新采样。
首先,我们拿到 ct_chunk——要么直接从缓存中取,要么回头去加载 CT(这一步将来在我们自己构造候选中心时也会派上用场)——然后把它转成 tensor(dsets.py:149, def getCtAugmentedCandidate)。
代码清单 14.11 读取 CT 数据并转换为 tensor 格式
def getCtAugmentedCandidate(
augmentation_dict,
series_uid, center_xyz, width_irc,
use_cache=True):
if use_cache:
ct_chunk, center_irc = \
getCtRawCandidate(series_uid, center_xyz, width_irc)
else:
ct = getCt(series_uid)
ct_chunk, center_irc = ct.getRawCandidate(center_xyz, width_irc)
ct_t = torch.tensor(ct_chunk).unsqueeze(0).unsqueeze(0).to(torch.float32)
接下来是仿射网格与采样代码(dsets.py:162, def getCtAugmentedCandidate)。
代码清单 14.12 应用仿射变换
transform_t = torch.eye(4)
# ... #1
# ... line 195
affine_t = F.affine_grid(
transform_t[:3].unsqueeze(0).to(torch.float32),
ct_t.size(),
align_corners=False,
)
augmented_chunk = F.grid_sample(
ct_t,
affine_t,
padding_mode='border',
align_corners=False,
).to('cpu')
# ... line 214
return augmented_chunk[0], center_irc
#1 对 transform_t 的各种修改会加在这里。
如果不加任何额外操作,这个函数本身其实不会做什么。下面我们就逐一看看,怎样把真正的变换加进去。
注意:非常重要的一点是,要把数据流水线组织成“先缓存,再增强”的顺序!如果反过来做,就会导致数据只被增强一次,然后以增强后的状态被持久化存下来,这就彻底违背了数据增强的初衷。
镜像翻转(Mirroring)
对样本做镜像翻转时,像素值本身完全不变,只改变图像的方向。由于肿瘤生长与左右方向或前后方向之间并没有很强的相关性,因此我们应该可以安全地在这两个方向上做翻转,而不破坏样本的代表性。index 轴(在病人坐标中对应 Z)则对应一个直立人体中的重力方向,所以肿瘤的“上”和“下”理论上可能存在差异。我们暂时假设没问题,因为快速目测并没有看到明显偏差。如果这是一个要走到临床可用程度的项目,我们就必须向领域专家确认这个假设是否成立(dsets.py:165, def getCtAugmentedCandidate)。
代码清单 14.13 通过翻转图像来做增强
for i in range(3):
if 'flip' in augmentation_dict:
if random.random() > 0.5:
transform_t[i,i] *= -1
grid_sample 会把 [-1, 1] 这个范围映射到新旧 tensor 的边界范围内(如果尺寸不同,缩放会自动隐式完成)。正因为是这种范围映射,想做镜像翻转其实非常简单:只需要把变换矩阵中对应轴上的对角元素乘以 -1 即可。
随机偏移(Shifting by a random offset)
把结节候选稍微移动一点,不应该造成太大问题,因为卷积本身具有一定平移不变性。不过,这样做会让模型对“结节没有恰好居中”的情况更加鲁棒。真正会带来更明显变化的是:这个 offset 不一定是整数个体素;因此数据会通过三线性插值重新采样,这会引入一点轻微模糊。样本边缘处的体素会被重复,因此边界处可能会出现一种拖影、涂抹式的边缘(dsets.py:165, def getCtAugmentedCandidate)。
代码清单 14.14 应用随机空间偏移以增强平移鲁棒性
for i in range(3):
# ... line 170
if 'offset' in augmentation_dict:
offset_float = augmentation_dict['offset']
random_float = (random.random() * 2 - 1)
transform_t[i,3] = offset_float * random_float
注意,'offset' 参数表示的是最大偏移量,它采用的就是 grid_sample 所使用的 [-1, 1] 归一化尺度。
缩放(Scaling)
轻微缩放图像,与前面的翻转和偏移本质上很相似。它同样会在边缘产生前面提到的那种“边缘体素重复”现象(dsets.py:165, def getCtAugmentedCandidate)。
代码清单 14.15 应用随机缩放变换
for i in range(3):
# ... line 175
if 'scale' in augmentation_dict:
scale_float = augmentation_dict['scale']
random_float = (random.random() * 2 - 1)
transform_t[i,i] *= 1.0 + scale_float * random_float
由于 random_float 已经被映射到了 [-1, 1],所以无论你把 scale_float * random_float 加到 1.0 上,还是从 1.0 中减去,本质上都没有区别。
旋转(Rotating)
旋转是我们第一次必须仔细考虑数据结构本身的增强类型,因为如果处理不当,就有可能把样本变换成不再具有代表性的东西。回忆一下:CT 切片在 row 和 column 两个方向(也就是 X 轴和 Y 轴)上的间距是均匀的,但在 index(或 Z)方向上,体素不是立方体。因此,我们不能把这三个轴当成完全等价的。
一种可能的处理方式,是先对数据重新采样,使得 index 轴方向的分辨率与另外两个方向一致。但这其实不是真正的解决方案,因为那样一来这个轴上的数据会变得非常模糊、非常 smeared。即使插值得到更多体素,数据本身的 fidelity 也不会真正变好。所以,我们会把这个轴当成特殊情况处理,只允许旋转发生在 X–Y 平面中(dsets.py:181, def getCtAugmentedCandidate)。
代码清单 14.16 对图像进行旋转
if 'rotate' in augmentation_dict:
angle_rad = random.random() * math.pi * 2
s = math.sin(angle_rad)
c = math.cos(angle_rad)
rotation_t = torch.tensor([
[c, -s, 0, 0],
[s, c, 0, 0],
[0, 0, 1, 0],
[0, 0, 0, 1],
])
transform_t @= rotation_t
噪声(Noise)
最后一种增强方式与前面几种有个本质不同:它会以一种“主动破坏”的方式修改样本,而翻转或旋转则不会。如果给样本加太多噪声,那么真实信号就会被彻底淹没,导致样本几乎不可能再被正确分类。虽然如果我们把偏移或缩放参数设得极端一些,它们也会产生类似问题,但我们这里所选的参数只会主要影响样本边缘。而噪声则会作用于整幅图像(dsets.py:208, def getCtAugmentedCandidate)。
代码清单 14.17 添加随机高斯噪声
if 'noise' in augmentation_dict:
noise_t = torch.randn_like(augmented_chunk)
noise_t *= augmentation_dict['noise']
augmented_chunk += noise_t
前面那些增强方式,是通过制造“新的但仍合理的视角”来增加数据集的有效规模;而噪声则是在让模型的任务变得更困难。等看到训练结果之后,我们会再回头讨论它。
检查增强后的候选样本
图 14.21 展示了前面努力的结果。左上角是一个未经增强的正类候选,接下来的五张则分别展示了每一种增强方式单独使用时的效果。最后,底部一行展示的是把所有增强方式组合在一起后,连续随机生成的三次结果。
图 14.21 在一个正类结节样本上施加的多种增强方式
由于增强版数据集的每一次 __getitem__ 调用都会重新随机执行增强,因此底部一排的每张图都长得不一样。这也意味着,几乎不可能再生成出完全一样的一张图!还要记得,有时 'flip' 增强也会“随机决定不翻转”。如果总是返回“翻过的图”,那其实和“从来不翻”一样,都太受限了。现在,我们来看看这些增强到底有没有带来效果。
14.6.2 观察数据增强带来的改进
我们准备再训练一些额外模型:前一节提到的每一种增强方式各跑一个模型,另外再跑一个“所有增强方式一起上”的模型。等这些模型训练完成后,我们就去 TensorBoard 里看数字。
为了能够灵活开关这些增强方式,我们需要把 augmentation_dict 的构造暴露到命令行接口上。我们的程序会通过 parser.add_argument(此处略去,但与前面已有参数的写法类似)引入这些选项,然后再由代码真正去构建 augmentation_dict(training.py:105, LunaTrainingApp.__init__)。
代码清单 14.18 通过命令行参数配置所有增强方式
self.augmentation_dict = {}
if self.cli_args.augmented or self.cli_args.augment_flip:
self.augmentation_dict['flip'] = True
if self.cli_args.augmented or self.cli_args.augment_offset:
self.augmentation_dict['offset'] = 0.1 #1
if self.cli_args.augmented or self.cli_args.augment_scale:
self.augmentation_dict['scale'] = 0.2
if self.cli_args.augmented or self.cli_args.augment_rotate:
self.augmentation_dict['rotate'] = True
if self.cli_args.augmented or self.cli_args.augment_noise:
self.augmentation_dict['noise'] = 25.0
#1 这些数值是通过经验挑出来的,效果算比较合理,但更优的值很可能依然存在。
现在这些命令行参数已经准备好了。你可以直接运行下面这些命令,或者回到 p2_run_everything.ipynb,执行第 8 到第 16 个单元。不管用哪种方式,训练都需要相当长时间:
$ .venv/bin/python -m p2ch14.prepcache #1
$ .venv/bin/python -m p2ch14.training --epochs 20 \
--balanced sanity-bal #2
$ .venv/bin/python -m p2ch14.training --epochs 10 \
--balanced --augment-flip sanity-bal-flip
$ .venv/bin/python -m p2ch14.training --epochs 10 \
--balanced --augment-shift sanity-bal-shift
$ .venv/bin/python -m p2ch14.training --epochs 10 \
--balanced --augment-scale sanity-bal-scale
$ .venv/bin/python -m p2ch14.training --epochs 10 \
--balanced --augment-rotate sanity-bal-rotate
$ .venv/bin/python -m p2ch14.training --epochs 10 \
--balanced --augment-noise sanity-bal-noise
$ .venv/bin/python -m p2ch14.training --epochs 20 \
--balanced --augmented sanity-bal-aug
#1 每章的缓存只需要预处理一次。
#2 如果你在本章前面已经跑过这个实验,那就不用再重复跑了!
当这些任务跑起来之后,我们就可以顺手启动 TensorBoard。为了只看这些 run,可以把 logdir 参数改成 ../path/to/tensorboard --logdir runs/p2ch14。
具体训练时间会非常依赖你手头的硬件。如果觉得耗时太长,你完全可以跳过 flip、shift 和 scale 这几个单独训练任务,也可以把第一和最后一个 20 epoch 的 run 降到 11 epoch,加快整体节奏。这里选择 20 epoch,主要是为了让它们在 TensorBoard 里和其他 run 更明显地区分开;11 个 epoch 通常也能看出趋势。
如果你让所有实验都完整跑完,那么 TensorBoard 中应该会出现类似图 14.22 那样的数据。为了减少画面混乱,我们会把除了验证集之外的所有 run 全部取消勾选。你在实际查看数据时,也可以调整 smoothing,这会有助于让趋势线更清晰。先快速看一眼图,然后我们再仔细分析。
图 14.22 在验证集上,使用多种增强方案训练得到的网络,其整体正确率、loss、F1、precision 与 recall 曲线
图 14.22 左上角那张图(tag: correct/all)里,首先要注意到的是:各个单独增强方式的结果其实相当杂乱。未增强 run 和“全部增强” run,正好落在这团混战的两侧。这说明:当各种增强方式叠加在一起时,其效果并不只是简单相加,而是会超出各自单独作用的总和。另一个值得注意的点是:完全增强版模型,错误样本其实更多。就总体来说,这当然不是好事;但如果我们去看右边那一列图(那些图聚焦的是我们真正关心的正类候选样本——也就是真正的结节),就会发现:完全增强版模型在发现这些正类候选方面好得多。它的 recall 非常漂亮!而且它在抵抗过拟合方面也明显更好。前面我们已经看到,未增强模型是会随着训练推进而越来越差的。
还有个有趣现象是:只加噪声增强的模型,在识别结节方面反而比未增强模型更差。这和我们之前说过的一致:噪声会让模型的任务变得更难。
另一个在实时数据里很有意思、但在这张图里因为太拥挤而不太明显的现象是:旋转增强模型的 recall 几乎和完全增强模型一样好,而它的 precision 却更好。由于我们的 F1 score 在本任务里主要是 受 precision 限制的(因为负样本远多于正样本),所以旋转增强模型的 F1 反而更高。
不过,后续我们仍然会沿用“完全增强”的模型,因为我们的实际 use case 要求高 recall。F1 仍将用于决定“保存哪个 epoch 作为最佳模型”。如果是一个真实项目,我们很可能会额外投入时间,去进一步探索:不同增强类型和参数值的组合,是否能带来更优结果。
14.7 结论
这一章里,我们花了很多时间和精力,去重构自己理解模型性能的方式。糟糕的评估方法非常容易把人带偏,而对“哪些因素决定了模型评价是否合理”形成强直觉,是极其重要的。一旦这些基本功内化之后,你就会更容易在实际项目中发现:自己是不是正在被错误指标误导。
我们还学会了如何应对那些样本量不够充足的数据源。能够合成出有代表性的训练样本,是一项非常有用的能力。说实话,训练数据多到用不完的情况,在现实里反而很少见。
现在,既然我们已经有了一个表现 reasonably 不错的分类器,接下来就要把注意力转向“自动发现候选结节”本身。第 15 章就会从这里开始。
14.8 练习
F1 score 实际上可以推广到不止 1 这一种情况:
- 阅读 en.wikipedia.org/wiki/F-scor… ,并实现 F2 与 F0.5 score。
- 判断在这个项目中,F1、F2 和 F0.5 哪一个最合理。跟踪这个值,并将其与 F1 对比。(是的,这已经在暗示你:答案不一定是 F1!)
为 ratio_int = 0 的 LunaDataset 实现一种基于 WeightedRandomSampler 的正负样本平衡方式:
- 你是如何获得每个样本类别信息的?
- 哪种实现更容易?哪种代码更可读?
试验不同的类别平衡方案:
- 什么 ratio 能在 2 个 epoch 后拿到最好分数?在 20 个 epoch 后呢?
- 如果 ratio 变成
epoch_ndx的函数,会怎样?
试验不同的数据增强方案:
- 现有增强里,能否把某些方式做得更激进一些(例如 noise、offset 等)?
- 噪声增强究竟是帮助还是损害训练结果?调整参数值后,结论会不会改变?
- 去看看其他项目中常见的数据增强方法。这里有适用的吗?试着为正类结节候选实现 “mixup” 增强。它有帮助吗?
把初始归一化从 nn.BatchNorm 换成某种自定义方式,再重新训练模型:
- 使用固定归一化,能得到更好结果吗?
- 什么样的归一化偏移和缩放更合理?
- 像开平方这类非线性归一化有帮助吗?
TensorBoard 除了本章提到的内容之外,还能展示哪些类型的数据?
- 它能显示网络权重的信息吗?
- 那么,能不能显示模型在某个特定样本上的中间结果?如果把模型 backbone 包成一个
nn.Sequential实例,会更有帮助还是更碍事?
小结
- 一个二元标签加上一个二元分类阈值,会把数据集划分成四个象限:true positives、true negatives、false negatives 和 false positives。这四个量构成了我们改进后性能指标的基础。
- Recall 衡量的是模型尽可能找到 true positives 的能力。把所有东西都选中,当然能保证 recall 完美(因为所有正确答案都包含进来了),但 precision 会极差。
- Precision 衡量的是模型尽量减少 false positives 的能力。什么都不选,当然能保证 precision 完美(因为不会包含任何错误答案),但 recall 会极差。
- F1 score 把 precision 和 recall 结合成一个单一指标,用来描述模型整体表现。公式是:
F1 = 2 * (precision * recall) / (precision + recall)。我们用 F1 score 来判断训练或模型改动到底对性能产生了什么影响。 - 类别极度不平衡的数据(例如负样本远多于正样本)会导致退化行为:模型只预测多数类。这会让少数类表现非常差,因此必须通过平衡或其他技术来处理。
- 在训练阶段把训练集平衡成正负样本数量相当,通常会得到更好的模型表现(这里所谓更好,是指能得到一个正的、并且持续增长的 F1 分数)。
- 当模型在训练集上表现很好、在未见数据上却表现糟糕时,就发生了过拟合。这意味着模型记住了训练样本,而没有学到可泛化的规律。过拟合的典型迹象,是训练集与验证集在 accuracy 或 F1 等指标上出现明显差距。
- 数据增强会对已有的自然样本做修改,使增强后的样本与原样本非平凡地不同,但仍保持同一类别的代表性。这样,在数据有限时,我们就能在不过拟合的前提下继续训练。
- 常见的数据增强策略包括:改变方向、镜像翻转、缩放、偏移,以及添加噪声。根据项目不同,也可能存在其他更有针对性的增强方式。