吐个槽先……我感觉我的英文写作比中文好,但写英文还是挺费脑子的,写中文就可以很随便……毕竟中文是母语……
这是我第一次训练神经网络(跑别人写的代码不算),一周都在琢磨怎么把GAN训练出来,今天不想费脑子了,而且好像正好关于GAN的中文资料也很少,就用中文写了
Generative Adversarial Network,生成式对抗网络,是一种生成模型(废话)。基于神经网络的生成模型很常见,但主要是在序列形式的数据上训练RNN,比如生成文本啥的。GAN采取了不同的思路:对抗训练。两个神经网络:一个生成网络,学习从随机噪声向量产生与训练数据相似的样本,和一个判别网络,学习判别一个样本是来自训练数据还是来自生成网络。对抗的目标是,生成网络提高判别误差,判别网络降低判别误差。
对抗训练也不算是新思路了,比如让两个AI对弈或者对话啥的很久以前就有,但用于生成模型算是个比较新的用法。
训练过程也很简单:对生成网络和判别网络的训练交替进行,每一回合,首先从训练数据sample一个batch,标记为真,然后使用生成网络生成一个batch,标记为假,用这两个batch训练判别网络;再使用生成网络生成一个batch,标记为真,让判别网络进行判别,然后把误差反向传播到生成网络进行训练。
虽然思路很简单,但实际操作的时候还是比较tricky的,很难训练好。原因有几个:首先,生成网络的质量无法量化(使用判别误差显然是不行的),只能靠主观判断,而凭主观判断是难以看出训练到底进行到什么程度、有没有卡住的,也就导致很难找到合适learning rate。
其次,两个网络的训练速度是很不同的。为了确保两个网络共同进步,要使它们保持能力相当。原paper认为判别网络应当多训练。但我的实测结果是在使用相同的learning rate时判别网络对生成网络基本上是吊打,不知道是什么情况。如果尝试保持每一回合结束后的判别误差约等于random guess的话,每个回合大概要训练生成网络20次以上……我还没有做实验,不过感觉可能是和输入噪声变量的数量有关。
不过上面两点都是次要的。GAN最常见的失败模式是生成网络对所有输入给出相同的输出。这个很好理解:每一个回合,生成网络的最优策略当然是找到当前判别网络表现最差的那一个样本,然后对所有输入都给出那个输出。不过接下来判别网络马上就会发现这个样本是假的。然后下一个回合生成网络又会找到另一个判别网络表现最差的样本,就这样这个最差样本在样本空间里变来变去,两个网络捉迷藏……
这大概也是原paper认为判别网络应该多训练的原因:生成网络不可能一步就使得所有输入都输出那个最差样本,如果判别网络多训练,及时发现生成网络的输出变化趋势并将这个样本判别为假,让这个最差样本在样本空间里跑得比较快,生成网络就追不上,也就不会对所有输入都给出相同输出了。
同样道理,learning rate不能太大。
但实际操作的时候,还是总是会发生这种生成网络collapse的情况……
我是在MNIST手写数字数据集上尝试训练的。网络结构什么的……这么简单的数据集其实随便什么结构都好吧……判别网络输入接两层卷积,接一层全连接,接输出,使用batch bormalization。生成网络正好反过来。
在调整网络结构、调整learning rate、调整两个网络的训练速度比均无果之后,我想了个不太优美的办法。判别网络每次只判别一个样本。如果可以一次判别整个batch呢?虽然每一个样本都像是真的,但整个batch长得一样,明显就是假的嘛……
所以怎样让判别网络在判别单个样本的时候可以参照整个样本的信息?一个方法是,求整个batch的平均值,然后将每个样本与平均值的差值作为一个额外的channel。如果整个batch很相似,这个channel的值(的绝对值)会比较小,否则会比较大。
加了这个trick之后,还真的就训练出来了……看一下训练过程,可以发现生成网络还是会有将所有输入收敛到同一输出的趋势,但马上就要collapse的时候,判别网络习得了真batch和假batch的统计差异,然后再过几个回合,生成网络的输出就发散了。
看一下结果。以下是生成网络生成的数字:
以下是真实数据:
还不错吧。不过也有一些明显的缺点:数字分布不均匀(1明显太多而8明显太少),以及有一些四不像的输出。再看看训练过程中输出的演化:
有一些很快就稳定了,还有一些跳来跳去,大概是位于输出不同数字的输入区域的边界上。
再看看其他好玩的事情,比如输入的128个标准正态分布随机变量都是干啥的。事实证明,不太好玩……由于128个太多了,而MNIST数据集很简单,没那么多的变化自由度,所以大部分变量只做了一点微不足道的工作,即使能看出来也难以解释到底是个啥作用……举个例子,以下每一组输出中只有一个输入从-2变化到2,其余变量固定:
以下每组是随机两个输入之间的线性插值:
以下是四个输入之间的双线性插值:
针对在不同数字生成区域边界上存在的四不像输出,有可以改进的办法:由于MNIST是有class label的,可以做有监督学习:将class label转换成one-hot vector放在输入的种子里,确保不同数字生成区域分离,然后将判别网络改成判别真假+分类、生成网络的目标改成增加判别误差、减小分类误差,可以期望产生更好的输出。目前还在尝试中,结果如下
看起来起码数字分布不均匀的问题是解决了……
这一周算是学习,接下来要把GAN投入实战来做本lab的一个项目了。
不过我自己是有其他打算的……GAN这么6的东西,赶紧拿来随机生成萌妹子啊!
想要高清无码大图是不太可能,而且看看一些在ImageNet之类的数据集上训练的结果可以知道,花花草草还行,想让GAN生成动物啥的结果会很猎奇……
但我想生成个眼睛啥的总还是可以做的吧,收集几万张图片也不是问题,有空试试。
细节很详细, Ma 了仔细看….. [赞][赞][赞]
LikeLike
楼主,问一下有代码么?
我训练10万次也没能训练成这个样子。
LikeLike
我可能发现了一个 typo: bormalization -> normalization
LikeLike