Diff-Instruct:指导任意生成模型训练的通用框架,无需额外训练数据即可提升生成质量

74 阅读4分钟

❤️ 如果你也关注 AI 的发展现状,且对 AI 应用开发非常感兴趣,我会每日跟你分享最新的 AI 资讯和开源应用,也会不定期分享自己的想法和开源实例,欢迎关注我哦!

🥦 微信公众号|搜一搜:蚝油菜花 🥦

原文链接:mp.weixin.qq.com/s/faeBUXbDs…


🚀 快速阅读

  1. 功能:Diff-Instruct 能从预训练扩散模型中提取知识,指导其他生成模型的训练。
  2. 原理:基于积分Kullback-Leibler散度,通过计算扩散过程中的KL散度积分来比较分布。
  3. 应用:适用于预训练扩散模型的蒸馏和改进现有GAN模型,提升生成性能。

正文(附运行示例)

Diff-Instruct 是什么

公众号: 蚝油菜花 - diff_instruct

Diff-Instruct 是一种先进的知识转移方法,专门用于从预训练的扩散模型中提取知识,并指导其他生成模型的训练。它基于一种新的散度度量——积分Kullback-Leibler (IKL) 散度,通过计算沿扩散过程的KL散度积分来比较分布。这种方法能够在不需要额外数据的情况下,通过最小化IKL散度,实现对任意生成模型的训练指导。

Diff-Instruct 的通用性和有效性在学术界受到广泛关注。它不仅可以显著提升生成模型的性能,还能在多种应用场景中发挥作用,如预训练扩散模型的蒸馏和改进现有的GAN模型。

Diff-Instruct 的主要功能

  • 知识转移:Diff-Instruct 能够从预训练的扩散模型中提取知识,并将其转移到其他生成模型中,无需额外数据。
  • 指导生成模型训练:作为一个通用框架,Diff-Instruct 可以指导任意生成模型的训练,只要生成的样本对模型参数是可微分的。
  • 最小化新型散度:Diff-Instruct 通过最小化积分Kullback-Leibler (IKL) 散度来实现知识转移,这种散度专为扩散模型设计,具有更高的鲁棒性。
  • 提升生成模型性能:Diff-Instruct 在多个实验中展示了其有效性,能够显著提升生成模型的性能,特别是在单步扩散模型和GAN模型的改进上。

Diff-Instruct 的技术原理

  • 通用框架:Diff-Instruct 提出了一个通用框架,可以指导任意生成模型的训练,只要生成的样本对模型参数是可微分的。
  • 积分Kullback-Leibler (IKL) 散度:Diff-Instruct 基于IKL散度,通过计算沿扩散过程的KL散度积分来比较分布,这种散度在比较具有不对齐支持的分布时更具鲁棒性。
  • 数据自由学习:Diff-Instruct 支持使用预训练的扩散模型作为教师来指导各种生成模型,无需额外数据。
  • 灵活性:Diff-Instruct 为生成器提供了非常高的灵活性,生成器可以是基于卷积神经网络(CNN)或基于Transformer的图像生成器,如StyleGAN,或者是从预训练扩散模型适应的基于UNet的生成器。

如何运行 Diff-Instruct

首先,克隆 Diff-Instruct 的 GitHub 仓库并设置 conda 环境:

git clone https://github.com/pkulwj1994/diff_instruct.git
cd diff_instruct

source activate
conda create -n di_v100 python=3.8
conda activate di_v100
pip install torch==1.12.1 torchvision==0.13.1 tqdm click psutil scipy

接下来,准备数据集并运行蒸馏过程。例如,对于 CIFAR-10 数据集的无条件生成,可以使用以下命令:

CUDA_VISIBLE_DEVICES=0 torchrun --standalone --nproc_per_node=1 --master_port=25678 di_train.py --outdir=/logs/di/ci10-uncond --data=/data/datasets/cifar10-32x32.zip --arch=ddpmpp --batch 128 --edm_model cifar10-uncond --cond=0 --metrics fid50k_full --tick 10 --snap 50 --lr 0.00001 --glr 0.00001 --init_sigma 1.0 --fp16=0 --lr_warmup_kimg -1 --ls 1.0 --sgls 1.0

在实验中,FID 值将自动计算并在每个“snap”轮次中显示。

资源


❤️ 如果你也关注 AI 的发展现状,且对 AI 应用开发非常感兴趣,我会每日跟你分享最新的 AI 资讯和开源应用,也会不定期分享自己的想法和开源实例,欢迎关注我哦!

🥦 微信公众号|搜一搜:蚝油菜花 🥦