3D医学图像重构

194 阅读4分钟

数据介绍

MRI

MRI, (Magnetic Resonance Imaging,磁共振成像), 是一种医学成像技术,对静磁场中的人施加无线电波脉冲,使人体中的氢原子核发生核磁共振(即氢原子的自旋特性),实现不切开身体就可以获取内部器官和组织的详细图像。如图1是核磁共振成像设备,图2是人脑核磁共振成像结果。

核磁共振成像设备.png

图1 核磁共振成像设备

核磁共振成像结果.png

图2 人脑核磁共振成像

DMRI

DMRI, (Diffusion Magnetic Resonance Imaging, 扩散磁共振成像), 是一种特殊的MRI技术,用于测量和成像水分子在组织中的扩散过程,特别适用于观察大脑中的白质束 (由大脑中神经纤维组成,用于传递大脑各部分之间的信息)。

NODDI

NODDI, (Neurite Orientation Dispersion and Density Imaging, 神经纤维方向散布与密度成像), 是一种高级的DMRI技术,用于描述大脑中神经纤维的方向和密度,可以提供更细节的信息相比DMRI。

Mango

Mango, (Multi-image analysis GUI, 多图像分析图像用户界面), 是一个处理医学图像的软件,可以查看、分析和标记医学图像。Mango官网 (下面参考列表有) 下载 Mango软件后,我们分别打开对人脑的DMRI图像和NODDI图像,如图3所示。一张图有3个小图,分别对应的上下扫、前后扫和左右扫的结果。如果我不告诉你哪一个是DMRI,哪一个是NODDI结果,你能自己推断出来嘛?

dmri和noddi结果.png

图3 人脑成像图

首先,图中大脑里像水草一样的物质就是神经纤维,可以看到它是充满我们大脑的,这东西损伤又可能会造成不可逆的伤害,因此,一般不打别人脑壳的。其次,NODDI是DMRI的进化版,可以观察到神经纤维的密度密度在图上表现出来就是颜色的亮暗。综上,第一张是NODDI成像的结果,第二张是DMRI成像的结果。

注意,dmri图像和noddi图像有采集方向的概念,一般用mango打开的话,一个slice就是一个方向。我现在的理解是为了测量水分子的扩散方向,于是从不同角度(通道)进行观察,就跟彩色图像的RGB通道一样,从不同视角观测同一个事物(可能比较肤浅,欢迎指正)。

不同方向的dmri结果.png

图4 不同方向的DMRI结果

模型介绍

VQ-GAN

我之前写过关于VQ-GAN的博客,见参考3. 如图4所示,VQ-GAN就是GAN里面加了个Codebook。为什么用GAN是因为GAN生成的图像质量高,质量高是因为判别器会提升生成器生成图片的质量。锦上添花加个Codebook是为了特征的量化,量化可以提高生成图像的质量(连续变量可能映射到不匹的图像块),压缩特征空间,稳定训练过程(连续变量要建立的映射理论上是数不尽的,而离散变量显然好学习)。

VQ-GAN.png

图5 VQ-GAN的架构

需要注意的是,VQGAN的损失中有一个自适应权重λ,

Q=argminE,G,ZmaxDExp(x)[LVQ(E,G,Z)+λLGAN({E,G,Z},D)],λ=GL[Lrec]GL[LGAN]+δ.\mathcal{Q^*}=argmin_{E,G,Z}max_{D}E_{x\sim p(x)}[\mathcal{L}_{VQ}(E,G,Z)+\lambda \mathcal{L}_{GAN}(\{E,G,Z\},D)],\\ \lambda=\frac{\nabla_{G_{L}}[\mathcal{L_{rec}}]}{\nabla_{G_{L}}[\mathcal{L}_{GAN}]+\delta}.

之前理解是用来平衡不同损失之间的梯度大小,让训练更加稳定。我们也可以从如下角度去考虑,

  • 在更新判别器时,LVQ\mathcal{L}_{VQ}当作常数,λ\lambda没有什么作用。

  • 在更新生成器时,

    • 若重构效果差(好),辨别器容易(不容易)区分图片,Lrec\mathcal{L}_{rec}大(小),LGAN\mathcal{L}_{GAN}大(小),λ\lambda接近1,重构与判别都顾及。
    • 若重构效果差,判别器不容易区分图片,Lrec\mathcal{L}_{rec}大,LGAN\mathcal{L}_{GAN}小,λ\lambda大,更新生成器也会平衡下判别器的损失,从而两者都顾及。
    • 若重构效果好,判别器容易区分图片,Lrec\mathcal{L}_{rec}小,LGAN\mathcal{L}_{GAN}大,λ\lambda小,更新生成器也会平衡下重构上的损失,从而两者都顾及。

综上,λ\lambda可以稳定训练的过程。

图5是用VQGAN还原的手写数字,可以看到图片质量很高,图6是Ground Trues。

VQ-GAN还原的手写数字.png

图6 VQ-GAN生成的手写数字

VQ-GAN生成手写数字对应的GD.png

图7 手写数字对应的Ground Trues

注意,我们实现的时候往往看到会晚启动判别器,这是为什么呢? 后面实现的时候,我们注意到如果不晚启动判别器,这个λ\lambda往往会比较大(比较容易判别),从而导致损失为很大的负数原因就是那时候模型刚开始学习,自然容易判别出来真伪。因此,我们往往会晚启动判别器。

DDPM

之前博客也提到DDPM,(Denoising Diffusion Probabilistic Models)。扩散模型最初的灵感来源于物理学中的扩散过程。我们可以将其理解为一种逐渐给数据“加噪”的过程。假设我们有一张干净的图片,接下来我们不断往这张图片里加噪音,直到它变成纯粹的噪声。扩散模型的目标就是反向推理出从纯噪声到原始图片的路径。相比GAN的话,它无需复杂的训练过程,也可以生成高质量的内容。

前向扩散过程

首先,我们定义前向扩散过程(Markov Chain),逐渐将原始数据 x0x_0 转变为噪声数据 xTx_T。这个过程被建模为一个Markov链,每一步都通过添加高斯噪声来逐渐“污染”数据。公式如下:

q(xtxt1)=N(xt;1βtxt1,βtI)q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1 - \beta_t} x_{t-1}, \beta_t \mathbf{I})

其中,βt\beta_t 是一个预定义的噪声调度参数,决定了每一步添加的噪声量。 经过 TT 步之后,我们就得到了一个近乎纯噪声的分布 xTx_T,此时数据的信息几乎完全丧失。

反向扩散过程

现在问题来了,如何从纯噪声恢复出原始数据呢?这就需要反向扩散过程。根据贝叶斯定理,反向过程可以表示为:

pθ(xt1xt)=N(xt1;μθ(xt,t),Σθ(xt,t))p_\theta(x_{t-1} | x_t) = \mathcal{N}(x_{t-1}; \mu_\theta(x_t, t), \Sigma_\theta(x_t, t))

模型的任务就是学会这个均值 μθ\mu_\theta 和方差 Σθ\Sigma_\theta 的估计。

损失函数

简洁却强大 DDPM的核心是通过一种基于变分下界(Variational Lower Bound, VLB)的损失函数,来优化反向过程的参数。具体来说,我们希望最小化以下损失函数:

Lvlb=Eq[t=1TKL(q(xt1xt,x0)pθ(xt1xt))logpθ(x0x1)]L_{vlb} = \mathbb{E}_q \left[ \sum_{t=1}^T \text{KL}(q(x_{t-1}|x_t, x_0) || p_\theta(x_{t-1}|x_t)) - \log p_\theta(x_0|x_1) \right]

然而,为了简化训练,通常我们会将这个复杂的KL散度替换为一个更易于优化的形式,最终的目标变成了:

Lsimple=Et,x0,ϵ[ϵϵθ(xt,t)2]L_{simple} = \mathbb{E}_{t, x_0, \epsilon} \left[\| \epsilon - \epsilon_\theta(x_t, t) \|^2\right]

其中,ϵ\epsilon 是高斯噪声, ϵθ(xt,t)\epsilon_\theta(x_t, t) 是神经网络学到的噪声估计。

DDIM

虽然DDPM在生成质量上表现出色,但其逐步采样的过程往往需要几百到上千步,这使得生成过程非常耗时。为了解决这个问题,DDIM应运而生。

非马尔可夫链的引入

DDIM的核心思想是通过一种非马尔可夫链的设计,实现了对采样过程的加速。与DDPM不同,DDIM在反向过程不再依赖于条件概率,而是直接建模每一步的确定性映射:

xt1=αt1(xt1αtϵθ(xt,t)αt)+1αt1ϵθ(xt,t)x_{t-1} = \sqrt{\alpha_{t-1}} \left(\frac{x_t - \sqrt{1-\alpha_t} \epsilon_\theta(x_t, t)}{\sqrt{\alpha_t}} \right) + \sqrt{1-\alpha_{t-1}} \cdot \epsilon_\theta(x_t, t)

这里的 αt\alpha_t 是扩散过程中的累积噪声参数。

采用DDIM加速生成并不需要修改训练过程,只要采用DDIM跳步方式sample,具体代码如下

# sample_steps 是自己设定的跳步数
self.sample_timesteps = torch.linspace(timesteps- 1, 0, sample_steps).long().to(device)
...
@torch.no_grad()
def ddim_sample(self, xt, condition_tensors= None, eta= 0):
    for t, tau in list(zip(self.sample_timesteps[:-1], self.sample_timesteps[1:]))[::-1]:
        sqrt_alphas_cumprod_tau= self.alphas_cumprod[tau]** 0.5
        sqrt_alphas_cumprod_t= self.alphas_cumprod[t]** 0.5
        sqrt_one_minus_alphas_cumprod_t= (1.0- self.alphas_cumprod[t])** 0.5
        t_= torch.ones((xt.shape[0], ), device= xt.device, dtype= torch.long)* t
        pred_noise= self.denoise_fn(xt, t_, lowres_cond_img= condition_tensors)
        # + sqrt_one_minus_alphas_cumprod_t* self.post_hoc_guiding_mechism(xt, condition_tensors, F.l1_loss)
        sigma= (eta* torch.sqrt(extract(self.betas, t_, xt.shape)))
        fist_term= sqrt_alphas_cumprod_tau* (xt- sqrt_one_minus_alphas_cumprod_t* pred_noise)/ sqrt_alphas_cumprod_t
        sec_term= ((1.0- self.alphas_cumprod[tau]- sigma** 2)** 0.5)* pred_noise
        eps= torch.randn_like(xt)
        xt= fist_term+ sec_term+ sigma* eps
        return xt

任务

任务描述

给定三维的低分DMRI影像,要求得到对应的高分NODDI影像。

评估指标

PSNR

PSNR, (peak singal noise ratio, 峰值信噪比), 所谓峰值就是像素最大值信噪反映了噪声图片的噪声大小。MSE越大,则信号噪声越大,则PSNR(I, K)如

PSNR(I,K)=10log10(MAXI2/MSE(I,K)).PSNR(I, K)= 10 \cdot log_{10}{({MAXI^2}/{MSE(I,K)})}.

其中,MAXI是原图I的最大像素值,MSE(I,K)是原图I和带噪声图K的差异,反映了信噪大小。

PSNR越小,则噪声越大,图像丢失高频信息越多,反之噪声越小,高频信息越多,图像越清晰

缺点

PSNR基于MSE,然而MSE本身并不能很好的如人眼的感受来衡量两张图象的相似度,如图8所示,

SSIM论文反例图.png

图8 船图

图a与其他图的PSNR都很接近,这是否意味着其他图之间的差异就比较小呢?我们可以看出有些图之间存在人眼可见的差异,如图8中的f和e,f对细节描述比e清楚多了,于是引入了SSIM指标。

SSIM

SSIM, (structural similarity, 结构相似性), 它关注于图像的结构信息,包括了亮度对比度结构特征的相似性。具体计算上,用像素均值反映亮度大小,标准差反映对比度强弱,协方差反映结构一致与否,故SSIM(I, K)计算如下,

SSIM(I,K)=2μxμy+C1μx2+μy2+C12σxσy+C2σx2+σy2+C2σxy+C3σxσy+C3.SSIM(I,K)=\frac{2\mu_{x}\mu_{y}+C_{1}}{\mu_x^{2}+\mu_y^{2}+C_{1}}\cdot\frac{2\sigma_{x}\sigma_{y}+C_{2}}{\sigma_x^{2}+\sigma_y^{2}+C_{2}}\cdot\frac{\sigma_{xy}+C_{3}}{\sigma_x\sigma_y+C_{3}}.

如上式所示,SSIM是三个分式的连乘积,每个分式的值域是[0,1][0,1],分别对应亮度相似性、对比度相似性和结构相似性。对于 2xy+cx2+y2+c\frac{2xy+c}{x^2+y^2+c},x与y越接近,函数值越接近最大值1,加个c防止分母为0. 相应函数图,如图9所示。

2xy比x2+y2图.png

图9 分式对应函数可视化结果

注意,psnr和ssim都被python的scikit-image库所包含,且ssim中的结构相似性是采用了滑动窗口实现的,从而可以很好的捕捉图像的局部信息,减少全局噪声的影响,但同时有选择窗口大小,计算量变大的问题。

具体实现

我们基于3d diffusion modelVQGAN,引入低分DMRI条件,设计了我们的模型,如图10所示。

模型框架.png

图10 模型框架图

  • 训练时 高分NODDI的隐特征进行加噪,拼接低分DMRI隐特征,进行去噪。

  • 测试时 低分DMRI隐特征拼接噪声,通过1000次(为了加速推理,我们采用了跳步为10的DDIM)迭代去噪,得到高分DONNI的隐特征。高分NODDI的隐特征经过量化,VQGAN的解码就得到了高分NODDI。

实验结果

实验结果.png

图11 生成高分NOODI结果

如图11所示,相比低分NODDI,我们生成的结果显然具有更高的质量,其PSNR25.39SSIM达到了0.931.

补充

白噪声的判别

根据白噪声的定义,其均值为0,方差为1,非自相关。均值与方差好判别,如何判别非自相关呢?它是一种用来测量同一序列中不同时间点之间的相似性的统计量。相似性的话,内积就可以判断;不同时间点,只要设置个时间延迟。

因此,自相关的定义为

γ(k)=E[(Xtμ)(Xt+kμ)]σ2.\gamma(k)=\frac{E[(X_t-\mu)(X_{t+k}-\mu)]}{\sigma^2}.

代码实现

均值,方差好计算,序列与不同时延序列的相似性可以用1d卷积实现,因此简单的代码实现如下,

def check_white_noise(x, threshold= 23e-3):
    mean_= x.mean()
    var_= x.var()
    x= x.view(1, 1, -1)
    padded_x= torch.nn.functional.pad(x, (x.shape[2]- 1, 0))
    result= torch.conv1d(padded_x- mean_, x- mean_, padding= 0).view(-1)/ var_/ x.shape[2]
    result= result[:-1]
    print(f'mean: {mean_}, var: {var_}, autocorrelation: {(int) ((result> threshold).sum()> 0)}')

代码

具体代码可见我的github.

参考

[1] 核磁共振成像原理

[2] Mango网址

[3] VQ-GAN

[4] DDPM

[5] DDIM