无需数据的模型压缩 - yubo105139/paper GitHub Wiki
[TOC]
来源:
https://mp.weixin.qq.com/s/cEbhLEsupd_r65iuaC1gZA https://mp.weixin.qq.com/s/1qvbUsdutizHHQMOci-6CA
涵盖四篇论文:
1 无需训练数据的网络压缩技术 DAFL (ICCV 2019)
**论文名称:**Data-Free Learning of Student Networks
1.1 DAFL 原理分析:
华为诺亚方舟实验室联合北京大学和悉尼大学提出了在无数据情况下的网络蒸馏方法 DAFL,比之前的最好算法在MNIST上提升了6个百分点,并且使用 resnet18 在 CIFAR-10 和 100 上分别达到了 92% 和 74% 的准确率 (无需训练数据)。
它的特点是:
- 待压缩网络看作一个固定的判别器D。
- 用生成器G输出的生成图片代替训练数据集进行训练。
- 设计了一系列的损失函数来训练生成器G。
- 使用生成数据结合蒸馏算法得到压缩后的网络。
主要步骤是:
- 通过待压缩网络训练生成器G
- 通过生成器G输出生成图片作为训练样本
- 通过训练样本蒸馏待压缩网络D得到压缩后的网络
1.11 知识蒸馏方法获得学生网络
蒸馏算法最早由Hinton提出,待压缩网络 (教师网络) 为一个具有高准确率但参数很多的神经网络,初始化一个参数较少的学生网络,通过让学生网络的输出和教师网络相同,学生网络的准确率在教师的指导下得到提高。通过已知的就是待压缩的大网络的输入和输出接口得到最终的学生网络
从训练数据的角度看,DAFL 的训练样本是由生成器G生成的,是没有标签的,所以作者引入了教师学生网络学习范式,利用蒸馏算法实现利用未标注生成样本对黑盒网络的压缩。
令$N_T
$和$N_S
$分别代表教师和学生网络,则作者使用交叉熵损失来使得学生网络的输出符合教师网络的输出,具体的损失函数为:
$\mathcal{L}_{K D}=\frac{1}{n} \sum_{i} \mathcal{H}_{\text {cross }}\left(\mathrm{y}_{S}^{i}, \mathrm{y}_{T}^{i}\right)
$
式中,$H_{cross}
$指交叉熵损失函数,$y_T^i=N_T(x^i)
$和$y_S^i=N_S(x^i)
$分别是教师和学生网络的输出。通过引入教师学生算法,作者解决了生成图片没有标签的问题,并且可以在待压缩网络结构未知的情况下对其进行压缩。
1.12 通过 GAN 生成无标注的训练图片
通过 GAN 来输出一些无标注的训练图片,以便于神经网络的压缩。生成网络希望输出和真实数据类似的图片来骗过判别器,判别网络通过判别生成图片和真实图片的真伪来帮助生成网络训练。
把待压缩网络作为一个固定的判别器D,以此来训练我们的生成网络G。
首先,待压缩网络作为一个固定的判别器D,我们就认为它是已经训练好参数的判别器$D^*
$,我们利用它来训练生成器的基本思想是下式:
$G^{*}=\arg \min _{G} \mathbb{E}_{z \sim p_{z}(z)}\left[\log \left(1-D^{*}(G(z))\right)\right]
$
式中,$D^*
$就是已经训练好参数的判别器,生成器G的参数经过上式持续优化使得$D^*(G(z))
$逐渐上升,代表着生成器的输出越来越能够骗过判别器。
但是,在传统GAN中,传统的判别器D的输出是判定图片是否真假 (Real or Fake?),只要让生成网络生成在判别器中分类为真的图片即可训练,但是,我们的待压缩网络为分类网络,其输出是分类结果 (1-num_classes),所以待压缩网络无法直接作为一个固定的判别器 。因此需要重新设计生成网络的目标。通过观察真实图片在分类网络的响应,作者提出了以下损失函数。
1) 伪标签交叉熵损失
在图像分类任务中,神经网络的训练采用的是交叉熵损失函数,在训练完成后,真实图片在网络中的输出将会是一个one-hot的向量,即分类类别对应的输出为1,其他的输出为0。于是,我们希望生成图片也具有类似的性质。给定一组任意的噪声向量${z^1,z^2,...,z^n}
$,它们通过生成器G之后得到的生成图片是${x^1,x^2,...,x^n}
$,这里$G(z^i)=x^i
$。
现在把这些生成图片${x^1,x^2,...,x^n}
$输入给待压缩的网络,通过$y_T^i=N_T(x^i)
$得到输出${y_T^1,y_T^2,...,y_T^n}
$ ,预测标签就是通过$\mathrm{t}^{i}=\arg \max _{j}\left(\mathrm{y}_{T}^{i}\right)_{j}
$计算得到 。定义伪标签交叉熵损失为:
$\mathcal{L}_{o h}=\frac{1}{n} \sum_{i} \mathcal{H}_{\text {cross }}\left(\mathrm{y}_{T}^{i}, \mathrm{t}^{i}\right)
$
其中$\mathcal{H}_{cross}
$就是标准的交叉熵函数,由于生成图片并没有一个真实的标签,我们直接将其输出最大值对应的标签设定为它的伪标签。伪标签交叉熵损失的意思就是对于一张生成的图片,它的标签就按照教师网络的输出来决定,这是训练生成器G的第1个损失。
2) 特征激活损失函数
在神经网络的训练中,由卷积核提取的特征也是输入图片的一种重要表示。先前的许多工作表明,卷积核提取的特征包含着图片的许多重要信息,将训练数据输入训练好的深度网络中,卷积核会产生更大的响应 (相比于噪声或与此网络无关的数据),基于此,作者提出了特征激活损失函数。定义生成图片$X^i
$经过教师网络得到的特征是$f_T^i
$,则特征激活损失函数定义为:
$\mathcal{L}_{a}=-\frac{1}{n} \sum_{i}\left\|f_{T}^{i}\right\|_{1}
$
反向传播优化生成器参数的方法是:
$\begin{gathered} \frac{\partial \mathcal{L}_{a}}{\partial f_{T}^{i}}=-\frac{1}{n} \operatorname{sgn}\left(f_{T}^{i}\right) \\ \frac{\partial \mathcal{L}_{a}}{\partial W_{G}}=\sum_{i} \frac{\partial \mathcal{L}_{a}}{\partial f_{T}^{i}} \cdot \frac{\partial f_{T}^{i}}{\partial W_{G}} \end{gathered}
$
因为待压缩网络 (即教师网络) 是训练好的,所以目标是让生成图像在待压缩网络中的特征响应值更大,来使图片更接近训练数据。这里作者采用了1范数来优化,原因是1范数相比于2范数会产生更加稀疏的值,而神经网络的响应也常常是稀疏的。
3) 信息熵损失函数
为了让神经网络更好的训练,真实的训练数据对于每个类别的样本数目通常都保持一致,例如MNIST每个类别都含有 6000 张图片。于是,为了让生成网络产生各个类别样本的概率基本相同,作者引入信息熵,信息熵是针对一个概率分布而言的。假设现在有概率分布$p=(\frac{1}{k},\frac{1}{k},...,\frac{1}{k})
$,概率分布p的信息熵的计算方法就是:
$\mathcal{H}_{info}(\mathrm{p})=-\frac{1}{k} \sum_{i} p_{i} \log \left(p_{i}\right)
$
概率分布p越均匀,信息熵$\mathcal{H}_{info}(p)
$就越小。极限情况当$p=(\frac{1}{k},\frac{1}{k},...,\frac{1}{k})
$时,信息熵$\mathcal{H}_{info}(p)
$取极大值$\frac{log(k)}{k}
$。所以信息熵损失函数定义为:
$\mathcal{L}_{i e}=-\mathcal{H}_{i n f o}\left(\frac{1}{n} \sum_{i} \mathbf{y}_{T}^{i}\right)
$
其中$\mathcal{H}_{info}(p)
$为标准的信息熵,信息熵的值越大,对于生成的一组样本经过待压缩教师网络的输出特征${y_T^1,y_T^2,...,y_T^n}
$,$y_T^i=\mathcal{N}_T(x^i)
$来讲,每个类别的数目就越平均,从而保证了生成样本的类别平均。
反向传播优化生成器参数的方法是:
$`\begin{gathered}
\frac{\partial \mathcal{L}{i e}}{\partial y{T}^{i}}=\frac{1}{n} y^{i}\left[\log \left(\frac{1}{n} \sum_{j} y_{T}^{j}\right)+1\right] \
\frac{\partial \mathcal{L}{i e}}{\partial W{G}}=\sum_{i} \frac{\partial \mathcal{L}{i e}}{\partial y{T}^{i}} \cdot \frac{\partial y_{T}^{i}}{\partial W_{G}}
\end{gathered}`$
最后,我们将这三个损失函数组合起来,就可以得到我们生成器总的损失函数:
$\mathcal{L}_{Total}=\mathcal{L}_{oh}+α\mathcal{L}_{a}+β\mathcal{L}_{ie}
$
通过优化以上的损失函数,训练得到的生成器可以和真实的样本在待压缩网络具有类似的响应,从而更接近真实样本,且生成的数据的分布十分均匀。
DAFL 的流程和算法如下图1和图2所示。把待压缩网络当做判别器 ,通过上式 12 作为损失函数来训练生成器 。通过生成器 来得到足够的生成图片,这些图片的分布与训练教师网络的训练数据是一致的。然后,再通过上式 1 的蒸馏损失和这些生成图片对教师网络进行蒸馏得到学生网络。
图1:DAFL 框架
图2:DAFL 算法
1.13 实验结果
作者在MNIST、CIFAR、CelebA三个数据集上分别进行了实验。
MNIST 实验
MNIST 数据集: 10类,60000 training+10000 testing。
作者实验了卷积模型和全连接模型,卷积模型使用 LeNet-5。全连接模型使用 Hinton 提出的具有3个全连接层的网络 Hinton-784-1200-1200-10 作为待压缩模型,将他们的通道数目减半分别作为学生模型 (LeNet-5-HALF 和 Hinton-784-800-800-10)。
图3的前三行是在原始数据集的实验结果。我们以 LeNet-5 模型为例。
- 使用原始数据集,教师网络可以达到 98.91% 的精度。
- 使用原始数据集,学生网络可以达到 98.65% 的精度。
- 使用原始数据集,学生网络+蒸馏方法可以达到 98.91% 的精度,无损。
- 不使用任何数据集,只使用噪声图片+蒸馏方法只能达到 88.01% 的精度。
- 使用另一个替代数据集USUP,学生网络+蒸馏方法可以达到 94.56% 的精度。
- 不使用任何数据集,使用之前的一个基于元数据的方法可以达到 92.47% 的精度。
- 不使用任何数据集,使用 DAFL 方法可以达到 98.20% 的精度。大大超越了之前的方法,并且比使用替代数据集得到的结果也要好很多,和使用原始数据得到的结果基本相似。
图3:MNIST 数据集实验结果
CIFAR 实验
CIFAR-10 数据集: 10类,50000 training+10000 testing。
CIFAR-100 数据集: 100类,50000 training+10000 testing。
作者还在 CIFAR-10 和 CIFAR-100 数据集上进行了实验,使用的教师和学生模型分别为 Resnet-34 和 Resnet-18。
图3的前三行是在原始数据集的实验结果。我们以 CIFAR-10 数据集的结果为例。
- 使用原始数据集,教师网络可以达到 95.58% 的精度。
- 使用原始数据集,学生网络可以达到 93.92% 的精度。
- 使用原始数据集,学生网络+蒸馏方法可以达到 94.34% 的精度,轻微有损。
- 不使用任何数据集,只使用噪声图片+蒸馏方法只能达到 14.89% 的精度,相当于训练失败。
- 使用 CIFAR-10 的数据作为 CIFAR-100 的替代训练集,使用CIFAR-100 的数据作为 CIFAR-10 的替代训练集,虽然 CIFAR-10 和 CIFAR-100 非常相似,并且具有一些重叠的图片,然而,得到的结果距离使用原始数据集仍然有较大的差距,学生网络+蒸馏方法可以达到 90.65% 的精度,有损。证明了在实际情况中使用相似的数据集来替代原始数据集并不能取得很好效果。
- 不使用任何数据集,使用 DAFL 方法可以达到 92.22% 的精度。本论文提出的方法同样取得了和使用原始数据集的蒸馏算法相似的结果,并且超越了使用替代数据集的结果。
图4:CIFAR 数据集实验结果
CelebA 实验
CelebA 数据集:202599 training images
作者又在 CelebA 数据集上进行了实验,使用的教师和学生模型分别为 AlexNet 和 AlexNet-Half。GAN 模型取 DCGAN。
- 使用原始数据集,教师网络可以达到 81.59% 的精度。
- 使用原始数据集,学生网络可以达到 80.82% 的精度。
- 使用原始数据集,学生网络+蒸馏方法可以达到 81.35% 的精度,轻微有损。
- 不使用任何数据集,使用之前的一个基于元数据的方法可以达到 77.56% 的精度。
- 不使用任何数据集,使用 DAFL 方法可以达到 80.03% 的精度,同样取得了很好的结果。
图5:CelebA 数据集实验结果
对比实验
由于我们的方法由很多损失函数组成,我们通过消融实验来分析每个损失函数项的必要性。对比试验的数据集是 MNIST,教师网络是 LeNet-5,学生网络是 LeNet-5-HALF。
下图6是消融实验的结果,一个三个损失函数:伪标签交叉熵损失,特征激活损失函数,信息熵损失函数。可以看到,如果一个都不用,就相当于是直接使用噪声蒸馏学生网络,则准确率是88.01%。使用不同的损失函数,精度如图,每一项损失都很重要。
图6:消融实验的结果
小结
DAFL 是一个新的无需训练数据的网络压缩方法,它的特点是: 待压缩网络看作一个固定的判别器 ,用生成器 输出的生成图片代替训练数据集进行训练,设计了伪标签交叉熵损失,特征激活损失函数,信息熵损失函数来训练生成器 ,使用生成数据结合蒸馏算法得到压缩后的网络。
2 适合大数据集的无需训练数据的网络压缩技术 (Arxiv 2021)
论文名:Large-Scale Generative Data-Free Distillation
论文地址:
介绍
本文的方法框架如下图9所示。本质上和 DAFL 的两个阶段是一致的,都是先用生成器G输出的生成图片代替训练数据集进行训练,然后使用生成数据结合蒸馏算法得到压缩后的网络。
无数据蒸馏方法框架
2.11 知识蒸馏方法获得学生网络
作者使用交叉熵损失来使得学生网络的输出符合教师网络的输出,具体的损失函数为:
$\mathcal{L}_{\mathrm{KD}}=\mathbb{E}_{\boldsymbol{x} \sim p_{\mathrm{data}}(x)}\left[D_{\mathrm{KL}}(T(\boldsymbol{x}) \| S(\boldsymbol{x}))\right]
$
式中,$D_{KL}(·\|·)
$指KL 散度损失函数,描述的是教师网络和学生网络的输出的差异, 指训练数据的分布,这里的训练数据和 DAFL 一样后续通过 GAN 来生成。通过引入教师学生算法,作者解决了生成图片没有标签的问题,并且可以在待压缩网络结构未知的情况下对其进行压缩。
2.12 通过 GAN 生成无标注的训练图片
从训练数据的角度看,在整个网络压缩的过程中,我们都没有任何给定的训练数据,在此情况下,神经网络的压缩变得十分困难。所以作者通过 GAN 来输出一些无标注的训练图片,以便于神经网络的压缩。生成对抗网络 (GAN) 是一种可以生成数据的方法,包含生成网络G与判别网络D,生成网络希望输出和真实数据类似的图片来骗过判别器,判别网络通过判别生成图片和真实图片的真伪来帮助生成网络训练。
这个基本的流程和 DAFL 是一致的,但是本文的目标函数设计与 DAFL 有差别。
1) Inceptionism loss
这个损失函数设计来自于 Inceptionism: Going deeper into neural networks 这篇论文。Inceptionism-style 图像生成,又叫做 DeepDream,是一种在已训练好的网络情况下,可视化能产生特定输出的输入图片的样子的方法。比如现在有一个训练好的网络,我们想知道什么样的图片可以让这个网络分类为 "狗"。怎么做呢?首先用随机噪声初始化一个可训练的图片 trainable image,不断地更新其参数,使得这个已有的网络的输出与 "狗" 这个类越接近越好,也就是训练这个 trainable image,让网络输出更像 "狗"。
由于生成图片并没有一个真实的标签,将其输出最大值对应的标签设定为它的伪标签。伪标签交叉熵损失的意思就是对于一张生成的图片,它的标签就按照教师网络的输出来决定。
为了使得生成的图片更加接近自然图片,作者另外加了一个先验,比如相邻像素之间的特定联系,体现在加了一个归一化损失,最终Inceptionism loss如下:
$\mathcal{L}_{\text {Inc }}(x, \hat{y})=\mathcal{L}_{\mathrm{CE}}(x, \hat{y})+\lambda_{t} \mathcal{L}_{t}(x)+\lambda_{\ell_{2}} \mathcal{L}_{\ell_{2}}(x)
$
式中,$ \mathcal{L}_{t}
$是指 total variation loss,$\mathcal{L}_{2}
$是指$l_2
$范数,其加权就是这个归一化损失,$\mathcal{L}_{CE}
$是标准交叉熵损失。
2) Moment matching loss
Inceptionism loss 只对教师网络的输出做了约束,现在还没有对于中间层的约束。
作者考虑到了 Batch Norm 层,它可以帮助提供这些中间特征。Batch Norm 操作一般是通过滑动平均的均值和方差 居中 和 重新缩放 来归一化中间层的输出,所以说 Batch Norm 层其实隐式地存储了输入数据的一些信息。利用生成图片x的均值和方差构建loss,训练得到的生成器可以和真实的样本在待压缩网络具有类似的相应,从而更接近真实样本,且生成的数据分布十分均匀。
2.13 使用多个生成器
模式坍塌是 GAN 网络中的一个很严重的问题,生成器 往往不会输出各式各样的生成图片,反而会产生很单一的样式的图片,且不会随着输入隐变量的变化而变化。
为什么在这个任务中会产生模式坍塌的问题?一种可能的解释是:生成器持续地产生一种图片,这种图片经过教师网络会输出一个很像 One-hot 编码的结果,这个结果就会导致 消失了,即使其他的损失函数还没有优化到最优值。下图10就是一个典型的模式坍塌的例子,生成器 生成了很多的自然图片,但是无一例外都是红色的车。
图10:产生模式坍塌的问题
为了解决模式坍塌的问题,训练多个 Generator 是一种简单有效的解决模式坍塌问题的方法。每个 Generator 只生成一个类的数据,比如说在训练第1个 Generator 时,它的训练目标就是使得教师网络的输出和 越接近越好。生成的图片如下图所示。
图11:训练多个 Generator解决模式坍塌的问题
小结
本质上和 DAFL 的两个阶段是一致的,都是先用生成器 输出的生成图片代替训练数据集进行训练,然后使用生成数据结合蒸馏算法得到压缩后的网络。本文方法通过 Inceptionism loss 和 Moment matching loss,以及训练多个生成器来解决 DAFL 无法在大数据集 ImageNet 上使用的问题。