相关理论
对抗自编码器(AAE)是AE+GAN的生成模型,可以说是VAE和GAN的一种改进模型。
首先介绍一下生成式模型的基本方法,下图是 Goodfellow 2016年总结的和神经网络相关生成式方法的“家谱”。以极大似然为根节点,按照是否需要定义概率密度函数划分为:明确概率、没有明确概率。
明确概率又可以进一步划分为:准确求解、近似求解。准确求解就是可以直接通过数学方法来建模求解。但是在神经网络中,准确求解可能会出现时间复杂度极高、高维不可求、无法微分等问题。此时就需要通过近似求解来结果这种问题。近似求解可以分为:确定近似方法和随机近似方法。确定近似方法主要使用变分推断,采用最大化变分下界(确定的公式)的方法来求解,比如VAE就是使用变分推断求出变分下界,将潜在变量空间Z构造成高斯分布。随机方法也就是马尔科夫链蒙特卡洛采样这类算法,比如BM、RBM、DBN、DBM等,主要以MCMC算法为主。变分推断和蒙特卡洛采样其实挺相似的。在蒙特卡洛采样中,采样大量的样本来拟合真实数据的分布,而在变分推断中,使用简单分布来拟合真实数据分布。但是马尔科夫链蒙特卡洛采样速度慢并且不准确。
没有明确概率又可以分为:马尔科夫链蒙特卡洛采样、GAN两部分。今天要讲解的AAE就属于GAN的一种。
由于马尔科夫链蒙特卡洛采样需要进行大量的采样来拟合真实数据的分布,导致速度慢并且不准确。所以现在比较火的生成模型就是:VAE、GAN、pixel,在训练时可以直接采用BP算法,克服了MCMC训练存在的不足之处。
AAE
下面简单介绍一下AAE的模型。AAE是AE+GAN的生成模型,可以说是VAE和GAN的一种改进模型。
为什么是VAE的改进模型:VAE的目的是通过变分推断的方法使得潜在变量空间$z$服从多维高斯分布,而AAE的目的也是使得潜在变量空间$z$服从多维高斯分布。可以说说殊途同归的两种模型;
为什么是GAN的改进模型:AAE通过加入对抗网络的方法使得潜在变量空间$z$服从多维高斯分布;
AAE由两部分构成:
AE(编码器、解码器):对于所有的自编码器来说其目标都是样本重构。自编码器首先通过编码器把高維空间中的向量$X$,压缩成低維向量$Z$(潜变量),然后通过解码器把低維向量解压重构出$X$,然后使用均方差损失通过BP进行训练。如该图上半部分所示,编码器首先把一个样本手写字体1通过编码器压缩成一个低維向量$z$,然后通过解码器重构原始样本。
GAN(生成器、判别器):判别器用于区分真样本和假样本,生成器用于生成假样本。AE的编码器相当于GAN的生成器。
再介绍一下模型的变量:
$x$:输入样本数据向量;
$q(z|x)$:编码器分布函数;
$z$:潜在变量空间,编码出来的潜在变量,服从$z \sim q(z)$分布;
$q(z)$:AAE模型学习出来的分布;
$p(z)$:自己定义的任意先验分布,通常$N(0,I)$;
AAE的训练涉及到AE的训练和GAN的训练。AE训练比较简单,使用均方差作为损失函数,通过BP算法进行训练。由于涉及到GAN,下面来复习一下GAN。
GAN
下图是GAN的损失函数。下面以生成图片为例来说明一下GAN的公式和原理。
GAN由两部分构成:生成器、判别器。首先从左到右解释一下这个公式。$G$表示生成器,$D$表示判别器,$x$表示真实数据,$P_{data}(x)$表示真实数据的分布,$D(x)$表示判别器判断真实数据是否真实的概率,$z$表示随机噪声,$P_z(z)$表示噪声数据的分布,$G(z)$表示随机噪声通过生成器后生成的假图片、$D(G(z))$是判别器判断生成器生成的图片是否真实的概率。
再解释一下为什么要最小化$G$,最大化$D$:
- $G$是生成器,用于生成图片。它接收一个随机的噪声$z$,从噪声$z$中采样通过生成器来生成图片$G(z)$。生成器$G$的目的:$D(G(z))$是判别器判断生成器生成的图片是否真实的概率,生成器$G$应该希望自己生成的图片“越接近真实越好”。也就是说,$G$希望$D(G(z))$尽可能得大,这时$V(D,G)$会变小。所以式子最前面是最小化$G(min_G)$
- $D$是判别器,用于判别一张图片是不是“真实的”。它的输入参数是$x$,$G(z)$。考虑单样本,从真实数据分布中采样得到真样本$x$,经过判别器得到$D(x)$;将假样本$G(z)$输入判别器得到$D(G(z))$。判别器$D$的目的:$D$的能力越强,$D(x)$应该越大,$D(G(x))$应该越小。这时$V(D,G)$会变大。因此式子最前面是最大化$D(max_D)$
在训练过程中,生成器$G$的目标就是尽量生成真实的图片去欺骗判别器$D$。而判别器$D$的目标就是尽量把生成器$G$生成的图片和真实图片区分出来。这样生成器$G$和判别器D构成了一个动态的“博弈过程”。在最理想的状态下,生成器$G$可以生成足以“以假乱真”的图片$G(z)$。对于$D$来说,它难以判定$G$生成的图片究竟是不是真实的,因此$D(G(z)) = 0.5$。
下面给出GAN训练的伪代码:
训练AAE时,在训练AAE的对抗网络时,使用的也是这种训练方法。
下面再通过一张图来直观的理解一下GAN的训练过程。黑色虚线是真实数据的高斯分布,绿色的线是生成器学习到的伪造分布,蓝色的线是判别器判定为真实图片的概率。(标x的横线代表服从高斯分布x的采样空间,标z的横线代表服从均匀分布的采样空间)同理,对于AAE来说,黑色虚线表示真实数据的分布p(z),绿线表示潜在变量空间的分布q(z),蓝色虚线表示经过判别器得到的概率。
AAE训练
对抗自编码器使用SGD算法进行训练,每个批训练可以分成两个过程:
(1)样本重构阶段
SGD更新自编码器的编码器和解码器的参数, 使得损失函数最小化:
(2)正则化约束阶段
A. 首先更新判别器参数, 用于区分真实样本($p(z)$采样的样本,正样本)、自编码编码层生成样本$z(q(z)$负样本)
B. 然后更新生成器参数(自编码器的编码器),以此来提高混淆判别器的能力
和VAE比较
很多时候希望潜变量$z$服从于某个已知的先验分布$p(z)$,比如希望$z$的每个特征相互独立并且符合高斯分布;甚至在机器学习分类问题上,希望这个潜在的表征向量$z$与我们的分类标签有关。
AAE和VAE都可以在潜在变量空间生成高斯分布。在变分自编码器中,我们通过一大堆复杂的公式(如变分推理、参数变换等),进行变分贝叶斯推导,构造变分下界,然后转换成KL散度、重新参数化等构造出损失函数,来使得编码潜变量$z$服从于高斯分布。
下面是变分自编码器的结构图。
对抗自编码的目标也是如此,目的是约束潜在变量空间$Z$服从高斯分布。只不过其不需要一系列复杂的变分推理,而是借助于对抗网络训练的框架,就实现了$Z$的约束,使得$z$服从于高斯分布。
下图对抗自编码器和变分自编码器在手写数字集MNIST上的比较。ABCD是AAE、VAE经过数据降维、数据可视化转化成2维的结果图。每种颜色代表相关的标签。A、C表示z经过数据可视化转换成2维高斯生成的图。B、D表示z经过数据可视化转换成10个2维高斯生成的图。左图意在说明,对抗自编码器可以和变分自编码器一样使得z的分布服从高斯分布,甚至对抗自编码器的效果要优于变分自编码器(图中表现出尖锐的转变,潜在变量空间z施加高斯分布后,空间被填满)。E是经过训练的变分自编码器从潜在变量空间z采样,经过解码器生成的手写数字。
AAE $p(z|x)$选择
$q(z|x)$的3种选择:
- 确定性函数:$q(z|x)$是关于$x$的确定性函数。$q(z)$只与真实数据$x$有关;
- 高斯后验:假设$q(z|x)$是一个高斯分布,其均值和方差由编码器网络预测:$z \sim N(μ,σ)$
- 通用近似后验:$q(z)$由数据分布$p_d(x)$和随机噪声$p_η(η)$决定。
选择不同类型的$q(z|x)$可用于不同的任务。一般情况下后两种的效果要优于第一种。
$(z|x)$选择确定性函数情况下,网络必须通过利用数据分布的随机性来匹配$q(z)$和$p(z)$,但是由于数据的经验分布是固定的训练集,这可能产生不太平滑的$q(z)$。在高斯后验或通用近似后验的情况下,网络可以获得额外的随机性来源,可以产生平滑的$q(z)$
将标签信息纳入对抗网络
可以利用部分或完整标签信息来更加严格地规范自编码器的潜在表示。以10个二维高斯为例,加入标签的目的是强制每种高斯分布代表MNIST的单个标签。标签用one-hot向量表示,每个one-hot向量表示一个数字。加入标签之后训练的结果是每个高斯分布表示的是一类数字。
应用
- 监督对抗自编码器
在上边的对抗自编码器的结构图中,将类标签信息以one-hot向量形式输入解码器,此时潜在变量空间$z$学习的是图片样式信息。解码器利用标签信息和样式信息重构图像,这种结构迫使网络保留独立于隐藏层z的标签的所有信息。
上图左图展示了在MNIST数字集上训练这种网络的结果。每一行重构图像的样式相同(例如字体)。 - 半监督对抗自编码器
此时有两个对抗网络,第一个对抗网络用于约束标签类别分布$p(y)$,第二个对抗网络用于约束潜变量$z$的先验分布$p(z)$。自编码器的编码层$q(z,y|x)$用于预测标签$y$、潜变量$z$, 解码层$p(x|z,y)$输入预测标签$y、z$用于重构样本$X$。
$p(y) = Cat(y), p(z) = N (z|0,I), x \sim P(x|z,y)$
训练过程可以分成三个阶段:
(1)无监督重构阶段
(2)对抗正则化阶段
(3)半监督分类阶段:SGD更新$q(y|x)$以此最小化有标签数据交叉熵损失函数
半监督对抗自编码器在MNIST和SVHN数据集上的半监督分类结果(错误率) - 对抗自编码器的无监督聚类
对抗自编码器可以以无监督的方式从连续的潜在样式特征中分离出离散的分类特征。下图表示使用具有16个标签MNIST的AAE无监督聚类。每行对应于一个簇,第一个图像是簇头。
无监督聚类,对抗自编码器在MNIST数据集上的聚类效果(错误率)并且随着标签数量的增长,AAE聚类的错误率在降低。
此外,对抗自编码器还可以用于特征提取、数据降维、数据可视化等方面。
总结
AAE是在自编码器中引入对抗网络,从而限制潜在变量空间$z$服从高斯分布,达到和变分自编码器相同的目的。而相较于变分自编码器、AAE灵活性更好,理论上可以用潜在变量空间$z$的分布来拟合任意分布(离散或连续),并在数据生成、聚类、数据降维、特征提取、数据可视化方面取得了较好的效果。