LOADING

加载过慢请开启缓存 浏览器默认开启

Toward Accurate Cardiac MRI Segmentation With Variational Autoencoder-Based Unsupervised Domain Adaptation 论文复现

论文链接:Toward Accurate Cardiac MRI Segmentation With Variational Autoencoder-Based Unsupervised Domain Adaptation
论文主要解决了心肌分割的问题,提出无监督域适应方法,将bSSFP(源域)的知识迁移到LGE(目标域)中,实现无需目标域标注的高精度分割。
关于论文的前置知识,可见KL散度、ELBO、VAE等博客。

架构

传统VAE即变分自编码器只有Encoder与Decoder两部分,论文中的VAMCEI增加了分割器部分,并且通过若干个损失函数来对齐源域和目标域的特征空间。

根据架构图,源域和目标域图像都通过UNet风格的Encoder进行特征提取(到潜在z空间),z空间通过Decoder进行重建;z空间通过分割器进行预测。根据论文的复现,有7个损失函数:

  1. 源域预测结果与真实掩码的分割损失
  2. 源域和目标域图像的重建损失
  3. 源域和目标域的潜在空间分布分别与标准高斯分布的 KL 散度损失(VAM正则化损失)
  4. 源域的整体潜在分布与目标域的整体潜在分布之间的双向 KL 散度损失(全局特征对齐损失)
  5. 原型对比损失(局部特征对齐损失)
  6. 源域和目标域上生成器与判别器的对抗损失(隐式特征对齐损失)

论文复现代码见:cardiac_uda_vamcei

接下来重点解析论文中的关键数学推导,包括:

  1. 变分自编码器(VAE)的目标函数
  2. 显式全局特征对齐(KL散度推导)
  3. 显式局部特征对齐(原型对比损失)
  4. 隐式特征对齐(对抗损失)
  5. 多阶段框架中的知识蒸馏损失
    以下逐一详细解释:

关键公式与推导

1. VAE基础:变分下界(公式1)

VAE的核心目标是最大化观测数据 (x,y) 的对数似然,通过变分推断转化为可优化的下界:

logpθ(x,y)LBVAE(θ,ϕ)=DKL(qϕ(z|x)pθ(z))+Eqϕ(z|x)[logpθ(x|y,z)]+Eqϕ(z|x)[logpθ(y|z)]

变量说明:

  • x:输入图像(心脏MRI)
  • y:分割标签(LV/RV/Myo)
  • z:潜在变量
  • qϕ(z|x):编码器输出的后验分布(近似真实后验 pθ(z|x)
  • pθ(z):先验分布(标准正态 N(0,I)
  • pθ(x|y,z):解码器重建的图像分布
  • pθ(y|z):分割器预测的标签分布

三项分解:

  • KL散度项:

    DKL(qϕ(z|x)pθ(z))

    强制潜在空间 z 服从标准正态分布(正则化)。
    具体计算(公式2,两正态分布的KL散度有公式):

    DKL=12j=1Mi=1n(σij2+μij2logσij21)

    其中 M 为 batch 大小,n 为潜在空间维度,μij,σij 为第 j 个样本第 i 维的均值和方差。

  • 重建项:

    Eqϕ(z|x)[logpθ(x|y,z)]

    最大化重建图像 x^ 的似然,对应二值交叉熵损失(公式3):

    LR=i=1Mx^ilogxi+(1x^i)log(1xi)
  • 分割项:

    Eqϕ(z|x)[logpθ(y|z)]

    分割预测损失(公式4):

    Lseg=i=1M[LCE(yi,y^i)+LDice(yi,y^i)]

    结合交叉熵和 Dice 损失处理类别不平衡。


2. 显式全局特征对齐(公式5-8)

核心问题: 源域和目标域潜在空间分布不一致,导致域偏移。

解决方案: 最小化两域潜在分布的 KL 散度。

  • 双向 KL 散度(公式5):

    D[qϕs(z),qϕt(z)]=DKL[qϕs(z)qϕt(z)]+DKL[qϕt(z)qϕs(z)]

    传统方法用 L2 距离,本文创新性地采用对称 KL 散度更准确度量分布差异。

  • 小批量近似(公式6):

    DKL[qϕs(z)qϕt(z)]=[1Mi=1Mqϕs(z|xSi)]ln1Mqϕs(z|xSi)1Mqϕt(z|xTi)dz
  • 高斯近似(公式7):

    DKL1M2i=1Mj=1MEqϕs(z|xSj)[lnqϕs(z|xSj)lnqϕt(z|xTj)]
  • 独立维度分解(公式8):

    DKL=1M2k=1nj=1Mi=1M[lnσTikσSik12+σSjk2+(μSjkμSik)22σTik2+σSjk2+(μSjkμTik)22σSik2]

    其中 μSik,σSik 为源域第 i 个样本第 k 维的均值和方差,μTik,σTik 为目标域对应值。
    关键在于将复杂的多维积分转化为可计算的求和。


3. 显式局部特征对齐(公式9-10)

目标: 对齐同类特征,分离异类特征(跨域)。
举例来说,是为了对齐源域和目标域中心肌 Myo 的特征,分离源域心肌 Myo 与目标域右心室 RV 的特征。

  • 类别原型计算(公式9):

    Ck=i=1Mj=1nzijI(yij=k)i=1Mj=1nI(yij=k)

    其中 zij 为第 i 个样本第 j 像素的特征向量,I(yij=k) 为指示函数(像素属于类别 k 时为 1)。

  • 原型对比损失(公式10):

    Pro(qs,qT)=1Kk=1Kln[exp(CSk,CTk/τ)ikexp(CSk,CTi/τ)+exp(CTk,CTi/τ)]

    其中 , 为余弦相似度,τ 为温度系数。


4. 隐式特征对齐(公式11-12)

通过输出空间域判别器实现。

  • 判别器损失(公式11):

    Ldisd=ExSXS[log(Dis(PS))]+ExTXT[log(1Dis(PT))]

    目标:区分源域/目标域分割图 PS,PT

  • 生成器(编码器)损失(公式12):

    Ldisg=ExTXT[log(Dis(PT))]

    编码器试图"欺骗"判别器,使目标域分割图 PT 被误判为源域,实现隐式特征对齐。


5. 多阶段框架蒸馏(公式16)

目标: 融合互补模型知识,避免语义错误。

  • 知识蒸馏损失(公式16):

    Ldistill=i=1Kexp(pi/T)jexp(pj/T)log(exp(qi/T)jexp(qj/T))

    其中:

    • pi:教师模型平均概率(Target VAMCEI + Source VAMCEI)
    • qi:学生模型预测概率
    • T:蒸馏温度(软化概率分布)
    • K:类别数

物理意义: 最小化学生与教师输出的 KL 散度,传递"暗知识"(dark knowledge)。