GAN (Generative Adversarial Network)
There are two models in a GAN:
- G: Generator
- D: Discriminator
Training Loss Functions
The generator produces fake data that tries to fool the discriminator; its loss function is designed so that the probability of successfully fooling D is maximized.
The discriminator, on the other hand, receives two kinds of input. For real data drawn from the dataset, D should output "real"; for anything produced by G (no matter how convincing it looks to a human), D should output "fake". D's loss function combines these two requirements.
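To make this concrete, the sketch below writes both losses in the least-squares (LSGAN) form, which is also the form used by the repository discussed next. The toy nn.Linear generator/discriminator, the tensor shapes, and the variable names are illustrative placeholders only, not taken from any particular codebase:
import torch
import torch.nn as nn

# Toy 1-D generator and discriminator, just to make the loss terms concrete.
generator = nn.Linear(16, 32)      # maps a latent vector z to a fake sample
discriminator = nn.Linear(32, 1)   # maps a sample to a single realness score

real_data = torch.randn(8, 32)     # a batch of "real" samples
z = torch.randn(8, 16)             # a batch of latent vectors

# Discriminator loss (least-squares form): real -> 1.0, fake -> 0.0.
# detach() keeps G's weights out of this update.
d_real = discriminator(real_data)
d_fake = discriminator(generator(z).detach())
d_loss = 0.5 * (torch.mean((d_real - 1.0) ** 2) + torch.mean(d_fake ** 2))

# Generator loss: G is rewarded when D scores its output close to 1.0 ("real").
g_loss = 0.5 * torch.mean((discriminator(generator(z)) - 1.0) ** 2)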
A real example
In the following repository:
https://github.com/deNsuh/segan-pytorch
a GAN is implemented to improve speech quality (i.e., to reduce noise).
The core training logic is in model.py.
For D there are two inputs. One is a batch from the dataset, which is real (clean) data, so the expected output is 1.0. The other is the output of the generator; even if it sounds reasonably clean to a human, it is still fake (generated), so the expected output is 0.0.
The discriminator loss is the average of these two losses.
##### TRAIN D #####
# TRAIN D to recognize clean audio as clean
# training batch pass
outputs = discriminator(batch_pairs_var, ref_batch_var) # out: [n_batch x 1]
clean_loss = torch.mean((outputs - 1.0) ** 2) # L2 loss - we want them all to be 1
# TRAIN D to recognize generated audio as noisy
z = sample_latent()  # sample latent noise for the generator (same helper as in the G step below)
generated_outputs = generator(noisy_batch_var, z)
# detach() so that gradients from D's loss do not flow back into G
disc_in_pair = torch.cat((generated_outputs.detach(), noisy_batch_var), dim=1)
outputs = discriminator(disc_in_pair, ref_batch_var)
noisy_loss = torch.mean(outputs ** 2) # L2 loss - we want them all to be 0
d_loss = 0.5 * (clean_loss + noisy_loss)
After the loss is defined, we can update the weights:
# back-propagate and update
discriminator.zero_grad()
d_loss.backward()
d_optimizer.step() # update parameters
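For completeness, the optimizers referenced above (d_optimizer, g_optimizer) must be created before the training loop. The snippet below is only a sketch; the optimizer type and learning rate are assumptions for illustration, not taken from the repository:
import torch.optim as optim

# Hypothetical optimizer setup (optimizer choice and learning rate are
# assumptions, not read from the repository's configuration)
d_optimizer = optim.RMSprop(discriminator.parameters(), lr=1e-4)
g_optimizer = optim.RMSprop(generator.parameters(), lr=1e-4)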
Similarly, when training the generator, the output of G is fed to D, but now the target output is 1.0 (in contrast to the 0.0 target used when training D), so G is pushed to produce audio that D classifies as real.
##### TRAIN G #####
# TRAIN G so that D recognizes G(z) as real
z = sample_latent()
generated_outputs = generator(noisy_batch_var, z)
gen_noise_pair = torch.cat((generated_outputs, noisy_batch_var), dim=1)
outputs = discriminator(gen_noise_pair, ref_batch_var)
g_loss_ = 0.5 * torch.mean((outputs - 1.0) ** 2)
# L1 loss between generated output and clean sample
l1_dist = torch.abs(generated_outputs - clean_batch_var)
g_cond_loss = g_lambda * torch.mean(l1_dist) # conditional loss
g_loss = g_loss_ + g_cond_loss
Update the weights for G:
# back-propagate and update
generator.zero_grad()
g_loss.backward()
g_optimizer.step()
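Putting the two parts together, each training iteration first updates D on a batch and then updates G, repeating over the whole dataset. The loop below is only a minimal sketch of this alternation; compute_d_loss, compute_g_loss, data_loader, and num_epochs are hypothetical names wrapping the computations shown above, not functions from the repository:
# Minimal sketch of the alternating D/G updates per batch.
for epoch in range(num_epochs):
    for clean_batch_var, noisy_batch_var, ref_batch_var in data_loader:
        # 1) update D: real pairs should score 1.0, generated pairs 0.0
        d_loss = compute_d_loss(clean_batch_var, noisy_batch_var, ref_batch_var)
        discriminator.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # 2) update G: make D score generated pairs as 1.0, plus the L1 term
        g_loss = compute_g_loss(clean_batch_var, noisy_batch_var, ref_batch_var)
        generator.zero_grad()
        g_loss.backward()
        g_optimizer.step()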