I was planning to write a post after getting everything done, but it took longer than I'd like…
We are working on some generative models now, and for the last week I have been learning how to train a Generative Adversarial Network. GANs are super cool, but also notoriously difficult to train! They probably aren't the best choice for a machine learning newbie like me.
But doing something challenging is fun! Using something that works out of the box is just too lame. And if it already works well, it likely isn’t worth studying.
Here I'd like to talk about what I've learned so far. Let's see the result first. I'm getting something like this from a GAN trained on the MNIST dataset:
It is generating something. Not exactly digits, but they do resemble some kind of handwritten script…
I found that there are some practical issues not addressed in the DCGAN paper. It suggested using batch normalization. Batchnorm behaves differently during training and evaluation: when training, it uses the mean and variance of the current batch; when evaluating, it uses running statistics accumulated over all the examples it has seen. It seems logical to put the generator in evaluation mode while training the discriminator, and vice versa.
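To make that concrete, here is a minimal PyTorch-style sketch of the mode toggling. The tiny networks are placeholders, not my actual models; only the `train()`/`eval()` switching matters here:

```python
import torch
import torch.nn as nn

# Toy stand-ins for the real networks; only the mode toggling matters here.
netG = nn.Sequential(nn.Linear(16, 64), nn.BatchNorm1d(64), nn.ReLU(), nn.Linear(64, 784))
netD = nn.Sequential(nn.Linear(784, 64), nn.BatchNorm1d(64), nn.ReLU(), nn.Linear(64, 1))

noise = torch.randn(8, 16)
real = torch.rand(8, 784)

# Discriminator step: generator in eval mode, so its batchnorm layers
# use accumulated running statistics instead of this batch's statistics.
netG.eval(); netD.train()
with torch.no_grad():
    fake = netG(noise)
d_real, d_fake = netD(real), netD(fake)

# Generator step: the roles flip.
netG.train(); netD.eval()
g_out = netD(netG(noise))
```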
In the original GAN paper, in each discriminator training round, the "true batch" of real data and the "false batch" of generated examples are fed to the discriminator separately. It seems logical that every batch the discriminator sees should have the same distribution, so true and false examples should be mixed within the same batch.
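A sketch of what that mixing looks like, with placeholder tensors standing in for real and generated images:

```python
import torch

real = torch.rand(8, 784)    # placeholder real examples
fake = torch.rand(8, 784)    # placeholder generated examples

batch = torch.cat([real, fake], dim=0)
labels = torch.cat([torch.ones(8), torch.zeros(8)])

perm = torch.randperm(batch.size(0))     # shuffle so the two halves interleave
batch, labels = batch[perm], labels[perm]
# feed `batch` and `labels` to the discriminator as one mixed batch
```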
To avoid making the wrong call on the previous two issues, I tried both ways for each. But the suggestions in the DCGAN paper still didn't quite work out for me, even though I followed them as closely as I could. I couldn't use fractional-strided convolutions for the generator network because that kind of convolution is not readily available in torch… but otherwise I did exactly what was advised. Still, the generator kept collapsing every input to a single output. From what I've read, this is in fact the most commonly observed failure mode of GANs.
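For reference, a fractional-strided convolution is the same operation as a transposed convolution, and newer PyTorch releases expose it as `nn.ConvTranspose2d`. That wasn't an option for me here, so take this as a side note:

```python
import torch
import torch.nn as nn

# A fractional-strided ("transposed") convolution: with stride 2 it
# doubles the spatial resolution, which is how DCGAN's generator upsamples.
up = nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1)
x = torch.randn(8, 64, 7, 7)
print(up(x).shape)   # torch.Size([8, 32, 14, 14])
```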
Then I came up with a not-so-elegant trick to solve this problem. The generator collapses because it thinks some particular output is the single best image for tricking the discriminator into mistaking it for a real example, so it learns to produce that output from any input. But then the discriminator quickly learns that that particular example is fake, and the generator moves on to the next output that can trick the discriminator…
This happens because the discriminator looks at one example at a time. If we could somehow let the discriminator reject a batch when every example in it looks the same, the generator would no longer be tempted to collapse every input into the same output! How do we do this? My solution: for each batch, take its mean. Then, for each example in the batch, take its difference from that mean and concatenate the result with the example itself as additional channels. The expectation is that if the examples in a batch are all similar, the intensities in the additional channels will be small, and the discriminator can learn to pick up on that.
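In code, my trick looks something like this. A minimal PyTorch sketch; `add_deviation_channels` is just my name for it:

```python
import torch

def add_deviation_channels(x):
    """Append each example's deviation from the batch mean as extra channels.

    x: (N, C, H, W) -> (N, 2C, H, W)
    """
    mean = x.mean(dim=0, keepdim=True)      # per-pixel mean over the whole batch
    return torch.cat([x, x - mean], dim=1)  # original channels + deviations

batch = torch.rand(16, 1, 28, 28)           # e.g. a batch of MNIST-sized images
augmented = add_deviation_channels(batch)   # shape: (16, 2, 28, 28)
# If the batch has collapsed to near-identical images, the extra channels
# are close to zero everywhere: a cue the discriminator can learn to flag.
```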
But then there is a conflict with batchnorm. Since we want the discriminator to tell a true batch apart from a false batch, we cannot mix them in the same batch! On the other hand, the intensity difference between a true batch and a false batch is exactly what makes the additional channels useful, and if batchnorm normalizes each batch individually, that difference gets wiped out! To prevent this, we cannot let batchnorm normalize the batches individually; we have to mix them!
So do we not use batchnorm? But without batchnorm, the training becomes very unstable.
The solution is to mix the two batches after their additional channels have been calculated separately.
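A sketch of the fix, reusing `add_deviation_channels` from above (shapes and names are placeholders):

```python
import torch

real = torch.rand(16, 1, 28, 28)    # placeholder real batch
fake = torch.rand(16, 1, 28, 28)    # placeholder generated batch

# Compute the deviation channels per batch, *before* mixing, so each half
# is compared against its own mean...
real_aug = add_deviation_channels(real)
fake_aug = add_deviation_channels(fake)

# ...then mix the halves so the discriminator's batchnorm layers normalize
# them jointly and don't wipe out the intensity difference.
mixed = torch.cat([real_aug, fake_aug], dim=0)
labels = torch.cat([torch.ones(16), torch.zeros(16)])
perm = torch.randperm(mixed.size(0))
mixed, labels = mixed[perm], labels[perm]
```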
This method actually works! The result is not perfect, but at least the generator does not collapse anymore! During training, you can see that at some point the generator almost collapses, but then its outputs soon start to diverge again.
At this point I'm not quite sure how to get the rest of it right. Tune the training parameters, or go model shopping?
And then there is one last mystery. The GAN paper says we should train the discriminator more. But in reality the discriminator constantly beats the crap out of the generator even when I train the generator 10 times as hard! Is this normal?
Actually, the difficulty of balancing the training of the generator and the discriminator is another major factor that makes GANs hard to train, along with generator collapse. The third factor is probably that, since there are two networks competing, there is no single loss function to measure how well the training is going.
I’ll probably write something more elaborate if I do become better at training GANs.
And what is the one thing that I want to generate with a GAN?
Images of anime-style eyes!