An Intuitive way to understand GANs
Updated: Jun 5, 2021
You must have heard a lot of buzz about neural networks that can create fake data such as images and audio clips, so realistic that it is almost impossible for humans to tell they were produced by an algorithm rather than captured from the real world. At some point, the geek in you must have wondered: how does this work? In this article, you will build an intuition for the algorithm. I will not dig into the mathematics behind it, because that's a story for another day.
Your first introduction to GANs
It all began when Ian Goodfellow, often called the father of GANs, published the first paper on the Generative Adversarial Network (GAN) in 2014. GANs introduced an architecture that was not prevalent until then. They fall into the category of generative models, i.e. models that can produce images that never existed before. Traditional models like classifiers or regressors cannot accomplish this with their limited architectures, so GANs add a new twist. Even with the twist, the original model could only produce low-resolution black and white images, as you can see below:
Source : https://arxiv.org/pdf/1406.2661.pdf
Now, this is not so impressive, right? Since 2014, GANs have come a long way in their performance. One of the state-of-the-art models, StyleGAN, can produce such realistic-looking cats that it makes me sad I don't get to play with them.
A Sneak Peek of GANs working
I'll start with an analogy: a student trying to learn painting, and a tutor giving him feedback on areas to improve. A GAN has two separate neural networks: a Generator, which acts like the student, and a Discriminator (also called a Critic), which acts as the tutor. In one line, the generator generates data, and the discriminator gives feedback on how real the generated image looks compared to a real image.
Training such a system is not straightforward. Taking the same analogy, if the student were learning mathematics instead of painting, he would be easy to teach: there is a clear notion of right and wrong, and you can point out exactly where an answer went astray. In painting, there is no single "correct" answer. Similarly, in GANs it's hard to pinpoint exactly where the generator is failing, or to define the "correct" output. Take all the above images of cats: there is no one correct painting of a cat; every image does its justice of being a cat. You also want your student to be able to paint a variety of pictures. Imagine you put all your effort into teaching him, and for the rest of his life he can only draw a single painting; that is definitely not something we would like. In other words, we want our generator to produce images with rich diversity.
Therefore, training GANs is comparatively harder than training other models like classifiers or regressors.
The Training Game
Now let's dig a bit deeper into the system's working. To have diversity in the generated images, we sample random noise from a distribution and feed it to the generator. The generator usually consists of convolution layers, some form of normalization (batch norm, layer norm, AdaIN, etc.), and activation functions. The architecture varies from as simple as a couple of linear layers, to encoder-decoder designs, to recent ones that include the self-attention mechanism. Whatever the architecture, the generator produces a fake image. This generated image is then fed to the discriminator, which gives a score between zero and one stating how fake or real the output is, respectively. The discriminator is a pretty straightforward classifier trying to distinguish fake images from real ones. In the big picture, there is a minimax game going on: the generator wants to get better at convincing the discriminator that the generated image is real, while at the same time the discriminator gets better at classifying and tries not to get fooled by the generator. You can roughly say that training has converged when the discriminator thinks the fake images are real, as demonstrated in the diagram below.
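To make this concrete, here is a minimal sketch of the two networks in PyTorch. The layer sizes, the 64-dimensional noise vector, and the 28x28 image size are illustrative assumptions, not part of any particular paper:

```python
import torch
import torch.nn as nn

z_dim, img_dim = 64, 28 * 28   # assumed sizes for illustration

gen = nn.Sequential(           # generator: noise vector -> fake image
    nn.Linear(z_dim, 256),
    nn.BatchNorm1d(256),       # one of the normalizations mentioned above
    nn.ReLU(),
    nn.Linear(256, img_dim),
    nn.Tanh(),                 # pixel values squashed into [-1, 1]
)

disc = nn.Sequential(          # discriminator: image -> realness score
    nn.Linear(img_dim, 256),
    nn.LeakyReLU(0.2),
    nn.Linear(256, 1),
    nn.Sigmoid(),              # score between zero (fake) and one (real)
)

noise = torch.randn(16, z_dim) # a batch of 16 noise vectors
fake = gen(noise)              # 16 fake "images"
score = disc(fake)             # 16 scores, each in (0, 1)
```

Real generators for images typically use transposed convolutions rather than linear layers, but the flow (noise in, image out, score from the discriminator) is the same.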
First, the generator outputs the fake image, which is passed through the discriminator to get a fake_score. We want the discriminator to identify it as fake, i.e., the score should be as close to zero as possible, so the loss is calculated against zero; let's call it fake_loss. Then we show the discriminator a real image and get a real_score. We want it to identify this as real, i.e., the score should be close to one, so the loss is calculated against one; let's call it real_loss.
Now take the average of both losses; this is the final loss. Hence, we are teaching the discriminator to classify real and fake. Backpropagation then takes place with this final loss, and only the parameters of the discriminator are updated, not those of the generator.
PyTorch code for updating the discriminator is given below (the fake image is detached so no gradients flow back into the generator):
disc_opt.zero_grad()
with torch.no_grad():
    fake = gen(sketch)  # generate without building a graph for the generator
disc_fake_hat = disc(fake.detach(), sketch)
disc_fake_loss = adv_criterion(disc_fake_hat, torch.zeros_like(disc_fake_hat))
disc_real_hat = disc(real, sketch)
disc_real_loss = adv_criterion(disc_real_hat, torch.ones_like(disc_real_hat))
disc_loss = (disc_fake_loss + disc_real_loss) / 2
disc_loss.backward()
disc_opt.step()
Taking the same generated fake image, we pass it to the discriminator to get a score between zero and one. We want this score to be as close to one as possible, because the generator aims to fool the discriminator; hence the generator learns in the direction that leads the discriminator to give a score of one. The loss is calculated against one, and with this loss only the generator's parameters are updated; the discriminator's parameters are left untouched.
PyTorch code for updating the generator is given below (the fake image is regenerated here so that gradients can flow back into the generator):
gen_opt.zero_grad()
fake = gen(sketch)  # regenerate with gradient tracking on
output = disc(fake, sketch).reshape(-1)
loss_gen = adv_criterion(output, torch.ones_like(output))
loss_gen.backward()
gen_opt.step()
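Putting the two updates together, a self-contained sketch of one full training iteration looks like this. The tiny fully-connected networks, the random stand-in data, and the Adam hyperparameters are placeholder assumptions for illustration only:

```python
import torch
import torch.nn as nn

z_dim = 64
gen = nn.Sequential(nn.Linear(z_dim, 128), nn.ReLU(),
                    nn.Linear(128, 784), nn.Tanh())
disc = nn.Sequential(nn.Linear(784, 128), nn.LeakyReLU(0.2),
                     nn.Linear(128, 1), nn.Sigmoid())
gen_opt = torch.optim.Adam(gen.parameters(), lr=2e-4)
disc_opt = torch.optim.Adam(disc.parameters(), lr=2e-4)
criterion = nn.BCELoss()

real = torch.rand(32, 784) * 2 - 1   # stand-in for a batch of real images

# --- discriminator step: learn to classify real vs. fake ---
disc_opt.zero_grad()
fake = gen(torch.randn(32, z_dim))
fake_score = disc(fake.detach())     # detach: don't touch the generator
real_score = disc(real)
disc_loss = (criterion(fake_score, torch.zeros_like(fake_score))
             + criterion(real_score, torch.ones_like(real_score))) / 2
disc_loss.backward()
disc_opt.step()

# --- generator step: push the discriminator's score toward one ---
gen_opt.zero_grad()
gen_loss = criterion(disc(fake), torch.ones_like(fake_score))
gen_loss.backward()                  # gradients flow through disc into gen
gen_opt.step()                       # but only gen's parameters are updated
```

Running many such iterations over a real dataset is the whole training loop; everything else in GAN training is variations on these two alternating steps.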
Why is this the optimal way to train?
An interesting thing to note is that both networks are trained simultaneously, i.e. at any given point during training, both have roughly the same "IQ". Both networks are clueless at the beginning: the generator doesn't know how to generate realistic images and produces almost pure noise, and even the discriminator thinks "well, this is not so fake, it's kind of close to real" and gives a score that is not close to zero. In short, it's easy to fool the discriminator. Then, as training proceeds, each gets slightly better, step by step.
Let's see what happens if both the networks do not have the same "IQ":
When the generator is ahead of the discriminator: if the generator is very good at producing realistic images, the discriminator will assume every image is real and give a score of one. There is then no feedback for the generator to improve on, and it can get stuck at the same point, generating the same images every time.
When the discriminator is ahead of the generator: if the discriminator is very good at classifying fake versus real, the score for every generated fake image will be zero or close to zero, which is not useful feedback for the generator to improve upon. Hence, again, it becomes hard to increase the quality of the fake images.
The unappreciated story of noise vector
Let's see how the noise vector indirectly controls small aspects of the generated images.
First, let us see how we represent noise. Noise is just an N-dimensional vector of random numbers drawn from a distribution, as represented in the diagram above. This is where the GAN gets its capability to generate a diverse set of data. Intuitively, when you change the noise vector, i.e. sample from a different point in the distribution, you get newly generated data that differs from the previously generated data.
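In code, sampling such a vector is a one-liner; the 64-dimensional size and the standard normal distribution below are common but arbitrary illustrative choices:

```python
import torch

z_dim = 64                   # dimensionality of the noise vector (assumed)
z1 = torch.randn(1, z_dim)   # one sample from a standard normal distribution
z2 = torch.randn(1, z_dim)   # a different sample -> a different generated image
```

Feeding `z1` and `z2` to the same generator yields two different outputs, which is exactly the source of diversity described above.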
Initially, when training just begins, there is no correlation between the N dimensions and the details of the generated images. This does not stay the same over the course of training: like it or not, the generator learns to map each dimension of the vector to a particular feature in the generated data. For example, if you are generating images of human faces, tweaking a certain dimension, say the 7th, might correlate with the hairstyle; change the value of the 7th dimension while keeping the rest of the vector the same, and you get different hairstyles on the same face, with all other features intact.
That gives the rough idea, but it's not as straightforward as the example suggests. The correlation is usually not perceivable by humans; the generator learns the mapping with its own interpretation, and it's not possible to pinpoint that the Nth dimension is what changes a given feature in the output. Sometimes, if the vector doesn't have enough dimensions to map each feature to its own dimension, features get entangled across more than one dimension. Practically speaking, taking the earlier example, if you try to change the hairstyle you might end up changing the eye colour as well! Note that there is no way to control or manipulate the mapping process; you just sit back and enjoy. I believe that was enough insight into the noise vector.
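The "tweak one dimension" idea can be sketched directly on the noise vectors. The 64-dimensional vector and the choice of dimension 7 are hypothetical, matching the example above; with a trained generator, feeding these variants through it would sweep whatever feature that dimension learned:

```python
import torch

z_dim = 64
base = torch.randn(z_dim)          # one fixed noise vector

variants = base.repeat(5, 1)       # five identical copies, shape (5, 64)
variants[:, 7] = torch.linspace(-2.0, 2.0, 5)  # vary only the 7th dimension

# All other 63 dimensions stay identical across the five copies, so any
# difference in the generated images would come from that one dimension.
```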
Yann LeCun, Facebook’s chief AI scientist, has called GANs “the coolest idea in deep learning in the last 20 years.” This must have caught your attention, and applications in industries such as gaming, video conferencing, political campaigns, etc. are increasing rapidly.
Now I hope you have an intuition for GANs and how they work! Go check out research papers and journals for the in-depth details and recent advancements. Happy learning!