本章涵盖以下内容:
- 修改数据,使其可用于一个二维分割问题
- 使用 Segment Anything 执行分割
- 理解使用 SegFormer 进行掩码预测
- 微调一个分割模型
在前四章中,我们已经完成了很多工作。我们学习了 CT 扫描与肺部肿瘤、数据集与数据加载器、指标与监控;我们也把第一部分中学到的许多内容实际应用了起来,并且已经有了一个可用的分类器,用来判断 CT 扫描中的候选区域究竟是结节还是非结节。
然而,我们目前仍然处在一个多少有些“人为构造”的环境中,因为我们还依赖人工标注的结节候选信息(也就是 annotations.csv 和 candidates.csv 文件)来把候选区域送入分类器。那么,怎样才能自动判断 CT 扫描中的哪些部位是结节候选呢?这正是本章要解决的问题。
正如我们在开头解释过的,我们的项目采用多阶段方式来解决“定位可能的结节并将其识别为结节或非结节”这一问题。这是实践者中很常见的一种做法;而在深度学习研究中,则更倾向于展示单个模型端到端解决复杂问题的能力。本书采用的这种多阶段项目设计,正好为我们一步一步引入新概念提供了非常好的理由。
在本章中,我们将聚焦于一个与前几章不同的问题。出于教学与演示目的,你可以把这一章看作一个相对独立的部分。我们会引入一个新模型、一个新数据集,以及一套新的训练循环。我们的方法将会使用流行的开源模型,并学习如何将其中一个模型微调到适合我们的具体用例。我们会使用流行的开源模型,并学习如何针对我们的场景对它进行微调。前面几章虽然会提供有帮助的背景知识,但并不是理解本章的必要前提。本章对于理解整个端到端流水线也很有价值。现在就开始吧!
15.1 在项目中引入第二个模型
在前两章中,我们一直在处理图 15.1 所示计划中的第 3 步:分类。而在本章中,我们要回到第 2 步(见图 15.1)。我们需要找到一种方法,告诉分类器应该去哪里看。为此,我们要从原始 CT 扫描出发,找出一切可能是结节的区域。为了找到这些潜在结节,我们必须把那些看起来可能属于结节的体素标记出来,这一过程称为 segmentation(分割) 。
注意:我们预期会标出相当多并非结节的东西;这也正是我们设置后续分类步骤来减少这些“误报”的原因。
图 15.1 我们的端到端肺癌检测项目,本章聚焦于其中的主题:第 2 步,分割。
在本章的数据处理中,我们将基于 CT 扫描的二维切片来操作。这样做更便于演示,也更容易把分割结果可视化出来。我们当然也可以直接操作三维数据,但那样更难解释,也更难可视化。
把图 15.2 拆成步骤来看,本章的计划如下:
-
Segmentation(分割) ——首先,我们将学习如何借助 Segment Anything 模型来完成分割,包括这个新模型包含哪些组件,以及在整个分割过程中这些组件分别发生了什么。这对应图中的第 1 步。
-
Update(更新) ——为了真正把分割实现出来,我们需要在现有代码库中改动三个主要位置,这些改动显示在图右侧的子步骤中。整体结构会与我们为分类任务开发的代码非常相似,但细节会有所不同:
- 利用开源模型(step 2A) ——我们将下载并集成 Segment Anything 模型。第 14 章中的模型输出的是一个简单的 true/false 分类结果;而本章中的模型则会输出整张图像级别的结果。
- 修改数据集(step 2B) ——我们需要修改数据集,不仅要返回 CT 图像本身,还要同时提供结节的 mask(轮廓)。分类数据集由围绕结节候选区域裁出的三维小块组成,而分割训练与验证则需要收集完整 CT 切片以及二维裁剪块。
- 调整训练循环(step 2C) ——我们需要引入一个新模型,用我们的数据对它进行微调,并输出分割结果。
-
Results(结果) ——最后,我们会查看定量分割结果,看看前面这些努力带来了什么成果。
图 15.2 新的分割模型架构,以及我们将要实现的模型、数据集与训练循环更新。
15.2 分割的几种类型
开始之前,我们先来看看分割的几种不同类型。对于本项目,我们关注的是 semantic segmentation(语义分割) ,也就是对图像中的每一个像素逐个进行分类,分类标签与我们之前在分类任务中见过的那些标签类似,比如 “bear”“cat”“dog”等等。如果做得足够好,最终就会得到一些清晰的区域块,表示诸如“这些像素全部属于一只猫”。它通常表现为一个标签 mask 或热力图,用于指出感兴趣区域。对我们来说,标签会非常简单,是一个二元标签:True 表示该位置对应结节候选,False 表示无关的健康组织。这样一来,我们在一定程度上就可以满足“找出后续将送入分类网络的结节候选区域”的需求。
在深入细节之前,我们先简单讨论一下其他也可以用来寻找结节候选的方法。比如,instance segmentation(实例分割) 会给每一个单独的目标对象赋予不同的标签。在一幅包含多个肿瘤的医学图像中,实例分割会分别给每个肿瘤单独打标签,区分出 “tumor1” 和 “tumor2”。虽然这对于某些应用很有帮助,但相比语义分割,它通常更复杂、计算代价也更高;语义分割只关心把像素分到某个类别里,而不区分同类中的不同实例。
另一类常见方法是 object detection(目标检测) 。它会在图像中找到感兴趣对象,并画出一个包围框。虽然这对于识别某种特征是否存在很有用,但它缺乏医学分析所需的那种逐像素精度。鉴于我们的项目既需要精度又需要效率,因此这里我们选择聚焦于语义分割。
15.3 语义分割:逐像素分类
很多时候,分割要回答的问题形式是:“这张图里猫在哪里?” 很明显,大多数猫的照片——比如图 15.3——里面包含了大量“不是猫”的内容;背景里可能有桌子、有墙,还有猫坐着的键盘之类的东西。想要做到“这个像素属于猫,而那个像素属于墙”,其模型输出形式和内部结构都必须与我们此前使用过的分类模型根本不同。分类只能告诉我们图中有没有猫,而分割则会告诉我们猫在哪里。
图 15.3 分类会输出一个或多个二值标记,而分割会输出一个 mask 或热力图。
如果你的项目需要区分“近处的猫和远处的猫”,或者“左边的猫和右边的猫”,那么分割很可能就是更合适的方法。我们目前为止实现过的图像分类模型,可以被理解成一种漏斗,或者说放大镜:它把大量像素逐步聚焦到一个单点上(更准确地说,是聚焦到一组类别预测上),如图 15.4 所示。分类模型给出的答案是类似“是的,这一大堆像素里 somewhere 有一只猫”,或者“没有,这里没有猫”。当你不关心猫具体在哪里,只关心图里是否存在一只猫时,这就非常合适。
图 15.4 用于分类的“放大镜式”模型结构。
不断堆叠卷积层与下采样层,意味着模型最开始是从原始像素出发,逐步学会一些具体而细粒度的检测器,比如纹理、颜色;然后再继续向上构建更高层的概念性特征检测器,比如眼睛、耳朵、嘴巴和鼻子,最终汇总成“cat”还是“dog”的判断。随着每次下采样之后卷积层感受野不断增大,这些高层特征检测器就能利用来自输入图像越来越大区域的信息。
然而,分割要求模型输出的仍然是一个“像图像一样”的结果,因此最后只得到一个类似分类的二值标记列表显然不够。为此,我们会引入一个新模型。它的架构就是专门为处理分割任务设计的,其中利用了高级技术,例如 attention 机制和 transformer 类模型(我们在第 9 章中讲过)。这使它能够高效地处理并分割图像,在不丢失空间信息的同时,把注意力聚焦到图像中真正相关的部分。
15.3.1 Segment Anything 模型(SAM)
我们将用于这个语义分割问题的新模型,叫做 Segment Anything model(SAM) (参见 arxiv.org/pdf/2304.02…)。它由 Meta AI 于 2023 年发布并开源。该模型基于 Transformer 架构,并在一个包含大量图像及其对应分割 mask 的大规模数据集上完成训练。它能够处理非常广泛的对象类型,并且对新图像具有很强的泛化能力。
使用 Segment Anything,你可以上传一张图像,并且:
- 为图像中的所有对象创建分割 mask
- 接受点、框或 mask 形式的提示(prompt)
- 处理多种多样的对象,并对新图像具有良好的泛化效果
这个模型一经发布,就因为其在众多领域中的潜在应用而引发了极大关注,包括自动驾驶、机器人、图像编辑(去背景或去物体)、标注辅助等等。就像 PyTorch 曾经在深度学习领域带来关键性转变一样,Segment Anything 也很可能会对图像分割领域产生类似的变革性影响。基础模型(foundational models)已经在自然语言处理中取得了巨大成功,从 2018 年的 BERT 到后来的 ChatGPT 这类大语言模型皆是如此。然而,计算机视觉长期以来一直难以找到一种同样足够通用、能广泛适用于多种任务的架构。此前的模型虽然在特定任务上表现出色,但想迁移到新任务或更复杂场景时,往往需要大量重新训练。因此,SAM 是第一个真正展示出“图像分割基础模型”潜力的模型。
Segment Anything 的影响力,体现在它具备 zero-shot segmentation(零样本分割) 的能力,也就是说,它可以在无需针对特定数据集额外训练或微调的情况下,对图像中的对象进行分割。这相比以往那些通常需要任务专属训练数据才能取得最佳表现的分割模型,是一次重大进步。
这种方法其实借鉴了更早的自然语言处理模型:当它们在大规模数据集上训练完成之后,就可以通过 prompt 去执行很多不同任务。SAM 的作者识别出了一个合适的任务形式、一个合适的模型架构以及一个合适的数据集,并在 1100 万张图像与 11 亿个 mask 上对它进行了训练。本章中,我们将使用这个高度通用的模型,去对 CT 扫描中的结节进行分割。接下来,先来看看它的模型架构,然后再通过一个例子走一遍它的使用方式。
15.4 SAM 架构
SAM 使用的是一种基于 transformer 的架构,它由三个主要组件构成,如图 15.5 所示:
- Image encoder(图像编码器) ——这一部分负责从输入图像中提取特征。它采用的是一种 Vision Transformer(ViT)架构(参见 arxiv.org/pdf/2010.11…),会先把图像切分成 patch,然后利用 self-attention 机制对这些 patch 进行处理,以同时捕获局部与全局信息。
- Prompt encoder(提示编码器) ——这一部分负责对模型收到的 prompt 进行编码,例如点、框(稀疏 prompt)以及 mask(稠密 prompt)。就像图像会被表示成一系列数字供模型处理一样,这些 prompt 也会被编码成数值表示,以便驱动模型。它结合了卷积神经网络(CNN)和 transformer,对这些提示进行处理,并生成后续可以使用的 embedding。
- Mask decoder(掩码解码器) ——这一部分接收来自图像编码器的特征以及来自提示编码器的 embedding,并生成最终的分割 mask。它的设计灵感来自 transformer decoder 架构,会把这些信息结合起来,并在图像特征与 prompt embedding 之间执行 self-attention 与 cross-attention。Mask decoder 负责产出最终的分割结果,指出图像中哪些区域是感兴趣区域。
图 15.5 SAM 架构的高层示意图。为简洁起见,省略了本章不会用到的一些部分。
为了处理“单个 prompt 可能存在多种合理输出”的模糊性,模型会同时输出三个不同大小的 mask,以及对应的置信度分数。这三个 mask 通常是嵌套关系的:模型分别输出 whole、part 和 subpart。比如在图 15.5 的狗图像中,SAM 的输出 mask 就分别覆盖了整只狗、狗的鼻子,以及鼻子中的一个更小子区域。
SAM 的实现仓库位于:github.com/facebookres…。虽然本章中我们会使用一个已经训练好的版本,但这个仓库同样也提供了从头训练模型的代码。很多时候,一边阅读论文一边对照代码,会非常有助于真正理解模型架构及其工作原理。虽然本章不会从零手写这样一个模型,但我们仍然建议你花一点时间去检查一下代码,并基于你目前已经建立起来的知识体系,尝试识别出架构中各个 building block 在代码中的具体对应。此外,在使用开源模型时,也必须考虑模型所附带的许可证。
许可证(LICENSES)
对于个人项目来说,许可证问题通常没那么敏感,但无论如何,都应当清楚你所使用的开源软件附带了什么许可条款。Segment Anything 使用的是 Apache 2.0 许可证,这是一种较为宽松的开源许可证,它给予用户很大的使用自由,但同时也有一些要求,比如必须提供适当署名,并在重新分发代码时附上许可证副本。
另外,作者即便把作品发布在公开平台上(是的,即便是在 GitHub 上),版权依然归作者所有;如果他们没有附带许可证,那并不意味着作品自动进入公有领域。恰恰相反!这意味着你其实没有任何合法使用该代码的许可——就像你从图书馆借来一本书,并不意味着你有权整本拿去复制一样。还有一点需要注意:模型代码的许可证,可能与预训练权重的许可证并不相同,因此这两者都要分别留心。
既然我们已经找到了一个看起来非常适合当前问题的模型,接下来就需要把它适配到我们的具体需求中。一般来说,时时留意“有没有现成方案可以直接拿来用”总是好主意。你需要逐渐建立一种感觉:现有哪些模型、它们是如何实现和训练的、其中哪些部分可以“拿来主义”并应用到手头项目中。这样的广泛知识积累当然需要时间和经验,但越早开始搭建这套工具箱越好。
15.4.1 先试试一个现成模型,看看能不能直接用在我们的项目里
没有什么比亲手试一遍更能帮助理解模型的工作方式了,所以我们就通过一个示例来跑一遍。模型作者已经完成了绝大多数艰难工作,并提供了可供开源使用的预训练权重。前几章里,我们是通过 Hugging Face 的 transformers 库来加载预训练模型的;那个库里收录了很多开源模型。而在本章中,我们将直接从源代码仓库安装模型。
让我们打开 1_segment_example.ipynb,看看如何做到这一点。我们可以通过 pip install 从 GitHub 上的 Segment Anything 仓库(github.com/facebookres…)安装它(1_segment_example.ipynb)。
代码清单 15.1 从 GitHub 安装 Segment Anything
pip install git+https://github.com/facebookresearch/segment-anything.git
接着,我们导入这次演示要用到的库:
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from p2ch15.utils import get_sam_model
然后,我们会加载模型的某一个特定 checkpoint,其中包含预训练权重。Segment Anything 有三种不同大小的变体(“huge”“large”“base”);默认情况下,我们会加载中等尺寸的那个模型。我们已经实现了一个辅助方法 get_sam_model(),它会包含下载权重并把权重保存到本地 .pth 文件所需的 URL 信息。
我们还会设置运行设备:如果有 GPU,就使用 GPU,否则退回 CPU:
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_config, model_weights_path = get_sam_model()
sam = sam_model_registry[model_config](checkpoint=model_weights_path)
sam.to(device)
mask_generator = SamAutomaticMaskGenerator(sam)
SamAutomaticMaskGenerator 这个抽象,是 Segment Anything 库中一个很方便的封装,用来帮助从模型中生成 mask。它会自动处理生成 mask 所需的预处理和后处理步骤。
接下来,我们把第 10 章中的宇航员图片作为示例图像输入,来做分割:
image_path = "data/p2ch10/astronaut.png"
image = Image.open(image_path).convert("RGB")
现在,我们就可以把它送进 mask 生成流程里,生成所有的 mask:
# In[]
image_array = np.array(image)
masks = mask_generator.generate(image_array)
print(len(masks))
print(masks[0]["segmentation"].shape)
print(masks[0]["predicted_iou"])
# Out[]
79
(1024, 1024)
1.033138632774353 #1
#1 IoU 指的是 intersection over union(en.wikipedia.org/wiki/Jaccar…),理论上应该位于 0 到 1 之间。不过,这里模型其实预测的是一个“置信度分数”,而不是实际 IoU,因此它可以大于 1(说明它“非常有信心”)。
这个模型一共为图像生成了 79 个 mask,每一个 mask 的尺寸都和原始图像一样大。这些 mask 是二值数组,其中每个元素对应图像中的一个像素:True 表示该像素属于某个对象,False 表示不属于。相应的 score 张量则包含每个 mask 的置信度分数,用来反映模型对该分割结果的把握程度。
这些 mask 是根据图像中提取出来的特征生成的,如图 15.6 所示。例如,如果我们把其中一个 mask 画出来,就能看到宇航员的轮廓:
def plot_mask(mask):
int_mask = mask.astype(int)
plt.imshow(int_mask, cmap='gray', interpolation='nearest')
plt.axis('off')
plt.show()
plot_mask(masks[3]["segmentation"])
图 15.6 原始宇航员图像的一个 mask。白色像素表示 mask 中值为 True 的位置,黑色像素表示值为 False 的位置。
我们也可以把所有 mask 一起叠加画出来,看看它们是如何覆盖整张图像的(图 15.7):
plot_image_with_masks(image, masks["masks"])
图 15.7 多个 mask 叠加在宇航员图像上的效果
多么绚丽的“被分割过的宇航员”!我们现在已经能为整张图像生成 mask,但如果直接对模型提出更具体的问题,我们还能做到更加精细的分割。
15.5 直接使用 SAM 模型
前面我们使用了 SamAutomaticMaskGenerator 工具,它会为整张图像自动生成所有可能的 mask。而在我们的场景里,我们真正想要的是把 CT 扫描中的结节单独拎出来。为此,我们需要给模型提供某种额外指示,告诉它:“你应该在这里附近做分割。”
因此,我们不再使用 SamAutomaticMaskGenerator,而是直接导入并使用 Segment Anything 库暴露出来的模型组件。这一点在后面我们想要把模型微调到适配自身任务时,也会变得很重要。
前面已经提到过,SAM 可以接收点提示、框提示或 mask 提示作为输入。接下来我们先演示如何用点提示来驱动模型。首先,先加载模型;下面这段代码省略了若干导入和工具函数相关内容(2_point_prompt.ipynb)。
代码清单 15.2 加载 SAM
from segment_anything import SamPredictor, sam_model_registry
model_config, model_weights_path = get_sam_model()
sam = sam_model_registry[model_config](checkpoint=model_weights_path)
sam.to(device)
predictor = SamPredictor(sam)
接下来,我们需要提供一个输入点,用来告诉模型图像中的感兴趣区域在哪里。正如图 15.5 所示,这个点会先经过 prompt encoder,被编码为一种 embedding 表示。与此同时,图像会被切成若干 patch。这两个中间表示——embedding 和 patch——随后会被组合起来一起送入模型。这个预处理步骤与模型真正执行的部分是分开的,因此既灵活又高效。这一次,我们换用第 2 章中的那张狗的图片,并选择点 (320, 260) 作为输入提示:
image_path = "data/p1ch2/bobby.jpg"
image = Image.open(image_path).convert("RGB")
input_points = [[(320, 260)]]
input_points = np.array([(320, 260)])
predictor.set_image(np.array(image))
masks, _, _ = predictor.predict(input_points, point_labels=np.array([1]))
这会得到与前面类似的一组 mask 输出。模型会根据输入点生成 3 个不同大小的 mask(whole、part 和 subpart)。我们可以把图像和输入点一起画出来,看看这些 mask:
plot_mask_with_point(masks[0], input_points[0])
模型之所以支持多个输入 prompt,是因为它的输入本身就是按“多点”来设计的。这里我们只用了一个点,因此才会有 [0] 这种索引方式。根据这个提示点生成出来的狗鼻子 mask 如图 15.8 所示。
图 15.8 带有点提示的 mask
真是“鼻子够灵”的模型!至此,我们已经成功地把 Segment Anything 当作一个现成模型直接用起来了。不过,我们可不是来做“狗和狗鼻子的分割生意”的。我们的目标,是把目前看到的这些方法迁移到结节分割问题上,而这个问题属于一个完全不同的领域:医学影像。
通用图像模型 vs. 医学影像领域
事实上,在医学影像中做分割,有很多聪明而创新的方法;我们绝不是说这里只展示的就是唯一有效路线。必须明确指出,“自然图像”(也就是普通照片)与“医学图像”之间,存在明显的 domain gap(领域差异) 。
Segment Anything 是为通用图像分割设计的,因此它是否能在医学图像上同样表现出色,其实并不显然。实际上,已经有研究者在探索这个挑战,其中一些工作甚至专门针对医学影像任务对 SAM 做了改造(例如 arxiv.org/html/2304.1…)。
不过,我们认为,对于本书当前这个项目的目标而言,这种方法已经是最简单且足够好用的一种方案。我们更愿意先把事情保持在简单层面,这样就能把注意力集中在概念本身;至于那些更精巧的做法,等你掌握了基础之后,再去深入也不迟。
现在,既然我们已经拥有了“分割工具”,就可以把重点转回最初目标:对结节进行分割。
15.6 为分割任务更新数据集
本章的数据源其实并没有变:我们依然是在消费 CT 扫描及其注释信息。变化的是,我们的模型现在期望的输入形式与输出形式,已经不同于之前。之前的数据集输出的是三维数据,而现在我们需要的是二维数据。图 15.9 给出了我们目前所处位置的整体示意。
图 15.9 本章的整体脉络,此处聚焦于为了分割任务而需要对数据集做出的修改
15.6.1 绕开 SAM 只能处理二维数据的限制
第一个问题在于:我们的原始数据是体数据形式的三维图像,而 SAM 模型却是为二维图像处理而设计的。SAM 只能直接处理二维图像,无法直接消费三维体积数据。为了把 SAM 应用于我们的分割任务,我们必须先从三维 CT 扫描中抽取出二维切片,再逐张送入模型。
这种做法当然不是没有代价。由于 CT 切片通常在厚度方向上的分辨率低于行列方向,所以虽然我们得到的“视野宽度”会比表面上看起来稍大一些,但考虑到结节本身通常只会跨越有限数量的切片,这应该已经足够。无论是当前的二维方案,还是未来完整三维方案,都还有一个相关问题:我们暂时忽略了切片厚度的精确值。这意味着,模型最终必须学会对不同切片间距具备鲁棒性,也就是通过见到具有不同 spacing 的数据来适应它们。
总之,我们现在会使用 CT 扫描的二维图像切片来工作,而这在医学影像中其实是很常见的做法。通过把每一张切片视为一张独立图像,我们就能够使用 SAM 的能力,去分割出我们想要的结节。我们的目标是:给定一张 CT 扫描切片的二维图像,自动把感兴趣区域分割出来,以供后续分类模型使用。图 15.10 展示了一个例子。
图 15.10 带有结节分割 mask 的肺部 CT 扫描切片
为了收集这些图像,我们会把每一张切片看作空间中的一个独立位置,正如图 15.11 所示。图中给的是一个从头顶向下看的头颅示意,随着切片位置变化,会依次看到大脑、眼睛、牙齿,从上方一路经过中部到底部。类似地,我们也会以这样的方式去查看肺部的 CT 扫描。
图 15.11 CT 扫描中的每一张切片,都代表空间中的一个不同位置
为了微调我们的 SAM 模型,我们需要收集 CT 扫描切片图像,以及与之对应的感兴趣 mask。有了这些图像,我们就可以进一步训练模型,让它更准确地分割出我们需要的结节。下面开始构建用于分割的数据集。
15.6.2 构建分割数据集
首先要解决的一个问题是:我们的人工标注数据,与模型真正想输出的东西之间并不匹配。我们手头拥有的是标注点,但我们真正想要的是一个 mask,它能够指出图像中的任意给定像素是否属于结节。也就是说,我们必须借助基础版 SAM 和原始数据集中提供的点,自行构造出这个 mask。
我们将继续沿用前几章用过的 LunaDataset,只是加上一些小修改。首先,我们来创建一个辅助方法:利用候选信息中的 series UID 和 xyz 坐标,从 CT 数据中取出一张切片(dsets.py:113, Ct.getSingleSlice)。
代码清单 15.3 从三维 CT 扫描中提取单张二维切片
center_irc = xyz2irc(
center_xyz, #1
self.origin_xyz,
self.vxSize_xyz,
self.direction_a,
)
center_val = int(round(center_irc[axis])) #2
if axis == 0:
ct_slice = self.hu_a[center_val, :, :] #3
elif axis == 1:
ct_slice = self.hu_a[:, center_val, :]
elif axis == 2:
ct_slice = self.hu_a[:, :, center_val]
else:
raise ValueError("Invalid axis value. Must be 0, 1, or 2.")
return ct_slice, center_irc
#1 center_xyz 是从 candidateInfo_tup 中取得的,正如之前 getCandidateInfoList 返回的那样。
#2 获取切片应该从哪个中心体素索引位置取出。
#3 从三维数据中按索引取出对应的切片,从而得到后续要用的二维数据。
我们先取出中心点,再利用这个点从 CT 数据中索引出一张切片。接着,就可以从我们的 Dataset 子类里返回全部相关信息(dsets.py:211, Ct._get_single_item)。
代码清单 15.4 从数据集中取回 CT 切片及其元数据
candidateInfo_tup = self.candidateInfo_list[ndx]
ct_slice, center_irc = getCtSlice(
candidateInfo_tup.series_uid,
candidateInfo_tup.center_xyz,
)
ct_slice_tensor = torch.from_numpy(ct_slice).to(torch.float32)
pos_t = torch.tensor([
not candidateInfo_tup.isNodule_bool,
candidateInfo_tup.isNodule_bool
],
dtype=torch.long,
)
return ct_slice_tensor, pos_t, candidateInfo_tup.series_uid, \
torch.tensor(center_irc)
当我们从数据集中取一个样本出来时,可以先检查这张切片的 shape,确保它确实是二维的,并且确认其他信息也都正确(3_segment_ct_slice.ipynb)。
代码清单 15.5 验证 CT 切片的维度
# In[]
dataset = LunaDataset(sortby_str="label_and_size")
ct_slice, nodule_class, series_uid, center_irc = dataset[3]
ct_slice.shape
# Out[]
torch.Size([512, 512])
为了可视化,我们还可以把切片和对应的结节候选点一起画出来:
plot_single_slice(ct_slice, center_irc)
图 15.12 给出了一个肺部 CT 切片示例,其中结节位置被特别标出。
图 15.12 一张肺部 CT 扫描切片,其中结节位置用星号标出。
在前几章中,我们之所以缓存以结节候选为中心的小块 CT 数据,是因为我们不想每次只取一小块时,都重新从磁盘读取并解析整份 CT。现在对新的 ct_slice 也是同样道理,我们仍然希望使用缓存(dsets.py:151)。
代码清单 15.6 如何缓存 CT 切片
@raw_cache.memoize(typed=True)
def getCtSlice(series_uid, center_xyz):
ct = getCt(series_uid)
ct_slice, center_irc = ct.getSingleSlice(center_xyz)
return ct_slice, center_irc
prepcache 脚本会提前把这些值计算并保存好,从而帮助训练过程保持快速。现在既然切片已经准备好了,我们就可以开始构造微调所需的数据集了。
15.6.3 训练一个能够标出潜在候选区域的模型
在开始训练之前,我们需要两样东西:二维图像,以及与之对应的分割 mask。上一节已经拿到了二维图像,而 mask 还没有。我们可以使用刚才见过的那个预训练 SAM 模型来生成这些 mask。
在当前任务中,我们并不追求像素级“完美无瑕”的 mask。更重要的是:要能够从扫描图像中准确找出正确的 mask。上一章里构建的分类模型,其职责是判断某个候选区域是不是结节;而分割步骤的职责,是先从海量可能性中把所有潜在候选找出来。
为了生成这些 mask,我们将继续使用前面演示过的 SamPredictor 类。我们会把二维 CT 切片图像和对应点位置一起喂给模型,由模型生成 mask,而这些 mask 之后就会成为我们微调数据集中的“ground truth”。
准备 CT 图像与 mask
简单回顾一下:我们的微调数据集将需要两类图像——CT 切片图像,以及我们希望分割出来的结节对应的 mask 图像。只要有了这两种图像,你就可以训练很多不同的分割或目标检测模型,比如 U-Net、Mask R-CNN、DeepLab、YOLO 等等。本章中,我们会使用 SegFormer,这是一个基于 transformer 的模型,轻量且高效,非常适合语义分割任务。稍后我们会更详细地介绍它。
SAM 生成出来的这些 mask,为我们提供了一个低成本却高质量的训练集;如果完全靠人工来标注出这种数据,成本会非常高。我们正是利用这套数据,去微调一个轻量得多的 SegFormer 模型,让它能够在没有任何 prompt 的情况下,自动对 CT 切片做分割。
至于如何组织这套微调数据集,其实有很大自由度。一种做法是维护一个包含成对图像的大文件夹。归根结底,采用哪种组织方式并不关键,只要对自己(以及未来可能分享给别人时对别人)来说清晰易懂就行。我们这里采用两个独立文件夹的方式:一个保存 CT 图像,一个保存 mask 图像。另外,再配一个元数据文件,把图像与 mask 的详细信息串起来,如图 15.13 所示(4_create_dataset.ipynb)。
代码清单 15.7 为新的数据集建立目录结构
fine_tuning_dir = "data-unversioned/part2/fine-tuning/dataset"
ct_folder = f"{fine_tuning_dir}/ct"
mask_folder = f"{fine_tuning_dir}/mask"
图 15.13 我们的数据集结构:包含 CT 图像和对应的 mask
我们会创建一个辅助方法,自动生成这些文件夹,并把图像、mask 和元数据填充进去。这个过程将会复用我们之前开发好的 LunaDataset 中的数据:
def generate_ct_images_and_masks(original_ct_data, max_dataset_size=500, \
recompute=False):
...
for ct_slice, _, series_uid, center_irc in original_ct_data:
...
我们编写了一个 generate_ct_images_and_masks 方法,它接收一个数据集对象作为输入,这里我们会传入 LunaDataset。同时,这个方法还会检查图像和 mask 是否已经生成过;如果已经生成过,就会跳过当前实例。
为了创建 CT 图像,我们先把 tensor 转成 NumPy 数组,再把维度顺序从 (C, H, W) 调整成 PIL 图像库所期望的 (H, W, C):
ct_filepath = os.path.join(ct_folder, filename)
ct_image_array = np.transpose(scaled_ct_slice.numpy(), (1, 2, 0))
ct_image = Image.fromarray(ct_image_array, mode="RGB")
...
ct_image.save(ct_filepath)
在生成 mask 图像时,我们会使用 SAM,并提供一个输入点作为 prompt。这个过程会返回一个二维布尔数组(True / False),代表 mask,然后我们再把它保存成图像:
mask_filepath = os.path.join(mask_folder, filename)
x, y = center_irc[2], center_irc[1]
input_points = [[x, y]]
predictor = SamPredictor(sam)
predictor.set_image(ct_image_array)
masks, _, _ = predictor.predict(point_coords=np.array(input_points), \
point_labels=np.array([1]), multimask_output=False)
mask_image_array = masks[0]
mask_image = Image.fromarray(mask_image_array)
mask_image.save(mask_filepath)
最后,我们会把元数据保存在一个字典里,其中包括 series UID、IRC 点坐标,以及图像和 mask 的路径。这些元数据会被序列化到一个叫 metadata.jsonl 的文件里。有了这套结构之后,构建我们的 FineTuning 数据集类就变得很简单了:只需要从该文件中读取元数据,再高效地实现 __getitem__ 即可。
生成这些图像和 mask 可能要花上几分钟。为了方便起见,我们已经把数据公开上传到了 Hugging Face,地址是 mng.bz/gmGZ;你可以使用huggingface_hub 库把它下载下来:
if not os.path.exists(fine_tuning_dir):
from huggingface_hub import snapshot_download
repo_id = "H-Huang/LUNA16_segmentation_data"
snapshot_download(repo_id=repo_id, local_dir=fine_tuning_dir,
↪ repo_type="dataset")
现在既然图像和 mask 都有了,我们就可以创建 FineTuning 数据集类。这个类会读取元数据,并在训练时返回图像和 mask 的访问信息:
class FineTuning(dataset):
...
def __getitem__(self, index):
metadata = self.metadata_list[index]
ct_image_path = f"{self.fine_tuning_dir}/{metadata['ct_file_name']}"
mask_image_path = f"{self.fine_tuning_dir}/{metadata['mask_file_name']}"
series_uid = metadata["series_uid"]
center_irc = metadata["center_irc"]
return {
"series_uid": series_uid,
"center_irc": torch.tensor(center_irc),
"ct_image_path": ct_image_path,
"mask_image_path": mask_image_path,
}
fine_tuning_data = FineTuning()
15.7 为微调更新训练流程
现在,我们有模型了,也有数据了。自然而然地,下一步——也就是图 15.14 中的 step 2C——就是用这些数据把新模型训练起来。
图 15.14 本章整体脉络,此处聚焦于训练循环需要做出的修改
更具体一点说,为了完成模型训练,我们需要引入两个新的概念:
- 首先,我们得实例化这个新模型(这当然不奇怪)。
- 其次,我们会换用另一种优化器;仍然选择一个很流行的方案:AdamW。
15.7.1 什么是微调模型
我们的目标,是让模型能够自动分割出感兴趣的结节区域,从而在面对未见过的新图像时,也能自动检测出应当送入前几章分类模型的区域。之前使用 Segment Anything 时,我们必须提供一个点提示,模型才能完成分割。但现在我们希望把这个过程自动化:不再需要额外输入,就能直接对图像进行分割。为此,就必须在我们自己的 CT 切片与 mask 数据集上对模型进行微调。
前面提到过,我们将训练一个不同的模型,叫做 SegFormer。SegFormer 出自 Xie 等人的论文《SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers》(arxiv.org/abs/2105.15…)。它是专门面向语义分割设计的,并且可以针对特定数据集进行微调,以便在特定任务上——例如医学影像或自动驾驶——获得更好的表现。
所谓 fine-tuning(微调) ,是在机器学习中指:拿一个已经在大型通用数据集上训练好的预训练模型,再用一个新的数据集继续训练它,从而让它适应某个更具体的任务或领域。这种做法会利用模型已经从大规模通用数据中学到的知识,再进一步把这些知识“调整”到我们的特定任务上。具体来说,我们从一个已经学会了广泛特征的预训练模型出发,然后继续在自己的数据集上训练它。这样,模型就能在保留原有知识的同时,更好地适配新任务。
由于 SegFormer 足够轻量,因此它可以在单块 GPU 上完成训练,这使得它对于我们当前这些 CT 图像与 mask 来说非常合适。下面来看如何把它落到代码里。
设置模型
首先,我们需要设置数据集和数据加载器。这里会使用上一节创建好的 FineTuningDataset 类来加载图像与 mask;同时,再创建相应的 DataLoader,以便在训练过程中迭代这些数据(5_fine_tuning.ipynb)。
代码清单 15.8 创建训练与验证 DataLoader
from p2ch15.utils import FineTuningDataset
train_dataset = FineTuningDataset(split="train")
val_dataset = FineTuningDataset(split="val")
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=8)
我们将使用的模型叫做 SegformerForSemanticSegmentation,它来自 transformers 库。这个模型是为语义分割设计的,可以直接在我们的数据集上做微调。我们选择 nvidia/mit-b0 作为基础模型,它是 SegFormer 的一个轻量版本:
# In[]:
from transformers import SegformerForSemanticSegmentation
id2label = {"0": "background", "1": "nodule"}
label2id = {v: k for k, v in id2label.items()}
model = SegformerForSemanticSegmentation.from_pretrained(
"nvidia/mit-b0",
num_labels=2,
id2label=id2label,
label2id=label2id,
)
model.to(device)
# Out[]:
Some weights of SegformerForSemanticSegmentation were not initialized from↪
↪the model checkpoint at nvidia/mit-b0 and are newly initialized:↪
↪['decode_head.batch_norm.bias', 'decode_head.batch_norm.↪
↪num_batches_tracked', 'decode_head.batch_norm.running_mean',↪
↪ 'decode_head.batch_norm.running_var', ...]
You should probably TRAIN this model on a down-stream task to be able to↪
↪use it for predictions and inference.
你可能注意到了输出中提示:有些权重没有从 checkpoint 中初始化,而是新初始化的。这是正常现象,因为我们正在把模型微调到自己的任务上,这些新参数会在训练过程中被更新。图 15.15 展示了 SegFormer 的整体架构。
图 15.15 SegFormer 的架构。我们这里只微调 decoder 部分。
接下来,我们还需要设置优化器以及学习率调度器:
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00006)
15.7.2 使用 AdamW 优化器
我们会使用 AdamW 优化器,它是 Adam 的一个变体,加入了 weight decay。Adam 优化器(arxiv.org/abs/1412.69…)是训练模型时相对于 SGD 的另一种常见选择。Adam 会为每个参数分别维护一个学习率,并在训练过程中自动调整这些学习率。由于它能自动完成这些更新,所以多数情况下并不需要像使用 SGD 那样仔细指定一个非默认学习率;它会很快为每个参数找到一个还算合理的学习率。
代码中实例化 AdamW 的方式如下:
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00006)
通常来说,把 Adam 作为大多数项目的起始优化器,被认为是一个合理选择(参见 cs231n.github.io/neural-netw…)。当然,其他优化器配置在某些场景下也可能优于 Adam(例如带 Nesterov momentum 的 SGD),但想真正为某个项目找到合适的超参数,通常既困难又耗时,而且大多数时候会显得过度设计、甚至有点“为调参而调参”。
围绕 Adam 也衍生出了大量变体——AdaMax、RAdam、Ranger 等等,每一种都有自己的优缺点。深入这些细节超出了本书范围,但我们认为,至少知道这些替代方案存在,是很重要的。本章中我们选用 AdamW,一方面因为它正是 SegFormer 原论文中采用的优化器,另一方面我们也会沿用论文中相同的学习率设定。
15.7.3 设计训练循环
模型的输入是 CT 图像和 mask,但模型要求这些图像与 mask 必须被预处理成特定格式。具体来说,对于 SegFormer,transformers 库提供了一个 SegformerImageProcessor 类,专门负责预处理图像和 mask。它会自动完成缩放、归一化,以及把数据转换成模型期望的格式:
image_processor = SegformerImageProcessor()
def encode_inputs_for_model(image_paths, masks_paths=[]):
images = [Image.open(path) for path in image_paths]
masks = [Image.open(path) for path in masks_paths] or None
encoded_inputs = image_processor(images, masks, return_tensors="pt")
return encoded_inputs["pixel_values"].to(device), encoded_inputs[
↪"labels"].to(device)
现在就可以写训练循环了。我们会迭代数据加载器,把图像和 mask 喂给模型,然后计算 loss 并更新模型权重:
num_epochs = 20
for epoch in range(num_epochs):
model.train()
print("Epoch:", epoch)
total_train_loss = 0
num_train_batches = 0
for idx, batch in enumerate(tqdm(train_dataloader)):
pixel_values, labels = encode_inputs_for_model(batch[
↪"ct_image_path"], batch["mask_image_path"])
outputs = model(pixel_values=pixel_values, labels=labels)
loss = outputs.loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_train_loss += loss.item()
num_train_batches += 1
这样就会对模型训练 20 个 epoch。我们持续累计总 loss 和 batch 数量,以便在每个 epoch 结束时计算平均 loss。之后,我们还会在验证集上评估模型:
model.eval()
total_val_loss = 0
num_val_batches = 0
with torch.no_grad():
for batch in val_dataloader:
pixel_values, labels = encode_inputs_for_model(batch[
↪"ct_image_path"], batch["mask_image_path"])
outputs = model(pixel_values=pixel_values, labels=labels)
val_loss = outputs.loss
total_val_loss += val_loss.item()
num_val_batches += 1
average_train_loss = total_train_loss / num_train_batches
average_val_loss = total_val_loss / num_val_batches
print(f"Training Loss: {average_train_loss:.4f}, Validation Loss: {↪
↪average_val_loss:.4f}")
验证代码与训练循环非常相似,但有两个关键差别:模型会切换到 eval() 模式,同时通过 with torch.no_grad() 禁用梯度计算。这样既不会在验证期间更新参数,也能节省性能和显存。
由于我们已经把 loss 写入了 TensorBoard,因此还可以像图 15.16 那样,把训练过程可视化出来。随着模型训练推进,我们可以看到训练损失和验证损失都在稳定下降,这说明模型很可能正在学会对未见数据进行合理泛化。
图 15.16 在 TensorBoard 中绘制出的训练损失与验证损失曲线
15.7.4 保存模型
PyTorch 让“把模型保存到磁盘”这件事变得很简单。torch.save 在底层使用的是标准 Python pickle 库,因此理论上你可以直接把整个模型实例传进去,它也会被正确保存。不过,这通常不被认为是持久化模型的理想方式,因为那样会失去一些灵活性。
更常见、也更推荐的方式是:只保存模型参数。这样做的好处是,以后你可以把这些参数加载进任何一个“参数 shape 匹配”的模型中,即便这个模型类本身与最初保存时的类不完全一样。只保存参数,会让模型拥有更强的可复用性和可重组性,而不是被整个类定义死死绑住。
我们可以通过 model.state_dict() 来拿到模型参数(5_fine_tuning.ipynb)。
代码清单 15.9 把模型参数保存到磁盘
torch.save(model.state_dict(), "p2ch15/segformer_epoch_20.pt")
model.state_dict() 返回的是一个 Python 字典对象,它把模型中每一层映射到相应的参数张量,里面包括权重和偏置等内容。
提示:你不仅可以保存模型参数。torch.save() 完全可以处理任意字典结构,因此也可以顺手把优化器状态、时间戳、已完成的训练步数等额外信息一起存下来。如果你的算力资源是断断续续可用的,这对于无缝恢复训练、追踪实验进展会非常宝贵。关于如何加载模型与优化器状态并恢复训练,可以参考官方文档(mng.bz/eBqw)。
加载模型和保存一样简单。我们像平常一样先实例化模型,然后把文件中的参数读出来,再加载进去:
model = SegformerForSemanticSegmentation.from_pretrained(
"nvidia/mit-b0",
num_labels=2,
id2label=id2label,
label2id=label2id,
)
state_dict = torch.load("p2ch15/segformer_epoch_20.pt")
model.load_state_dict(state_dict)
model.to(device)
当你调用 torch.load("p2ch15/segformer_epoch_20.pt") 时,实际上就是把先前保存在该文件里的 state dictionary 重新读回内存。最后,model.to(device) 则把模型移动到指定设备(CPU 或 GPU)上,从而让它可以继续参与后续计算。
15.8 推理与结果
现在模型已经训练好了,它就可以被用于对未见过的 CT 扫描做结节分割。只要把 CT 切片输入模型,模型就能生成对应的 mask;随后,我们再把这些 mask 可视化出来,就能直观评估它的效果。
为了演示模型是如何生成 mask 的,我们会从数据集中选取特定点做示范。可以定义一个辅助方法,负责执行推理并返回 mask:
def inference(model, image_path, mask_image_path, show_plot=False):
model.eval()
image = Image.open(image_path)
mask_image = Image.open(mask_image_path)
pixel_values = image_processor(images=image, return_tensors="pt").
↪pixel_values.to(device)
with torch.no_grad():
outputs = model(pixel_values=pixel_values)
predicted_segmentation_map = image_processor.
↪post_process_semantic_segmentation(outputs,
↪ target_sizes=[(512, 512)])[0]
if show_plot:
plot_image_and_masks(image, mask_image,
↪ predicted_segmentation_map.cpu().numpy())
return predicted_segmentation_map
只要把图像路径和 mask 路径传给这个 inference 方法,模型就能生成 mask。我们还可以把图像与 mask 一起可视化出来,看看模型到底分得怎么样(图 15.17)。
图 15.17 CT 图像(左)、ground truth mask(中)与模型预测结果(右)
恭喜!我们的模型已经能够从 CT 扫描中把结节分割出来了。
15.9 结论
在本章中,我们探索了像素级到像素级的分割技术,重点放在 Segment Anything model(SAM)以及它在医学影像任务中的应用。我们介绍了 SAM —— 这是一个面向多用途分割任务的基础型计算机视觉模型,并展示了它如何通过可提示输入生成分割 mask。具体来说,我们演示了如何仅通过输入一张 CT 图像和一个参考点,就让 SAM 为候选结节生成 mask。
在此基础之上,我们接着构建了一个全新的数据集,其中包含图像及其对应的 mask,用作后续解决分割问题的资源。这为实验和分析提供了一个非常实用的框架,使我们能够进一步训练自己的模型。
最后,我们对 SegFormer 模型进行了微调,让它能够在 CT 扫描中分割结节。我们设计了一套训练循环来处理图像与 mask,并在 TensorBoard 中监控训练损失和验证损失;同时,还把模型参数保存下来,以便更灵活地复用。我们也展示了如何把这些参数重新加载回来,用于推理。最终,我们使用训练好的模型在未见过的 CT 扫描上生成并可视化分割 mask,从而验证了它在识别结节方面的有效性。
15.10 练习
对你自己选择的一个新数据集,对 SegFormer 模型进行微调:
- 你在数据集准备流程上做了哪些改动?
- 微调之后,模型表现发生了怎样的变化?
尝试使用不同优化器(例如 AdamW 与 SGD)来训练 SegFormer:
- 不同优化器在训练动态和最终性能上有什么差异?
- 学习率对训练过程有什么影响?
把分割数据集实现改成三路划分:训练集、验证集和测试集:
- 你为测试集保留了多大比例的数据?
- 测试集和验证集上的表现是否彼此一致?
- 为了更好地追踪实验进度,还可以记录哪些额外元数据?
实现一条“先分割、再分类”的结节检测流水线:
- 分割输出会如何影响分类结果?
- 这种做法的优势和挑战分别是什么?
- 除了 LUNA(或 LIDC)之外,你还能找到其他可用数据源吗?
小结
- 分割是在像素或体素级别上判断某个位置是否属于某个类别。这与分类不同,分类是在整张图像层面做判断。
- 分割有多种类型,包括:语义分割(把像素分到不同类别中)、实例分割(给单独对象赋予不同标签)、以及目标检测(用包围框定位对象)。
- Segment Anything model(SAM)代表了基础型分割模型的一次重要进展。它支持利用可提示的点信息来引导分割。
- SAM 由三个主要部分构成:处理输入图像的 image encoder、编码用户提示(如点、框、mask)的 prompt encoder,以及结合图像编码与提示信息来生成最终分割 mask 的轻量级 mask decoder。
- 通过“先分割、后分类”,我们可以在相对温和的数据量和算力要求下实现检测任务。
- 我们可以通过在新数据集上进行微调,把一个模型适配到新的任务上。
- Adam 系列优化器是训练模型时非常流行的一类选择,因为它会自动为每个参数调整学习率。
- SegFormer 是一种专为语义分割设计的 transformer 架构模型。
- 我们可以从 Hugging Face 下载预训练模型和可用数据集。
- TensorBoard 不仅能显示训练过程中生成的二维图像,还会记录这些图像随训练过程变化的历史,因此可以用来直观追踪模型输出的演化。
- 模型参数可以保存到磁盘中,也可以重新加载回来以恢复一个先前保存过的模型。只要旧参数和新模型参数之间存在一一对应关系,模型实现本身甚至可以有所变化。