GANs (Generative Adversarial Nets) were invented by Ian Goodfellow. The idea is to train two models at the same time: a generator and a critic. The generator tries to create new images similar to the ones in a dataset, and the critic tries to distinguish the real images from the ones the generator produces. The generator returns images, and the critic returns a single number (usually a probability: 0. for fake images and 1. for real ones).
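To make the input and output shapes concrete, here is a minimal sketch in plain PyTorch. It assumes tiny hypothetical models (not fastai's basic_generator or basic_critic): the generator maps a 100-dimensional noise vector to a 3x8x8 image, and the critic maps an image to a single probability of being real.

```python
import torch
from torch import nn

# Hypothetical toy generator: 100-d noise -> 3x8x8 image with pixels in [-1, 1].
generator = nn.Sequential(
    nn.Linear(100, 3 * 8 * 8),
    nn.Tanh(),                   # squash pixel values into [-1, 1]
    nn.Unflatten(1, (3, 8, 8)),  # (bs, 192) -> (bs, 3, 8, 8)
)

# Hypothetical toy critic: image -> one probability (near 0 = fake, near 1 = real).
critic = nn.Sequential(
    nn.Flatten(),                # (bs, 3, 8, 8) -> (bs, 192)
    nn.Linear(3 * 8 * 8, 1),
    nn.Sigmoid(),                # turn the score into a probability
)

noise = torch.randn(4, 100)
fake = generator(noise)
pred = critic(fake)
print(fake.shape, pred.shape)  # torch.Size([4, 3, 8, 8]) torch.Size([4, 1])
```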
We train them against each other in the sense that at each step (more or less), we:
- Freeze the generator and train the critic for one step by:
  - getting one batch of true images (let's call that real)
  - generating one batch of fake images (let's call that fake)
  - having the critic evaluate each batch and computing a loss function from that; the important part is that it rewards the critic for recognizing real images as real and fake images as fake
  - updating the weights of the critic with the gradients of this loss
- Freeze the critic and train the generator for one step by:
  - generating one batch of fake images
  - evaluating the critic on it
  - returning a loss that rewards the generator when the critic thinks those are real images
  - updating the weights of the generator with the gradients of this loss
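The alternating procedure above can be sketched in plain PyTorch. This is a toy sketch, not fastai's training loop: it assumes hypothetical tiny MLP models on 2-d "images" and uses binary cross-entropy as the loss. Freezing a model is done here by toggling requires_grad on its parameters.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical toy models working on 2-d "images" and 8-d noise.
generator = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
critic = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 1), nn.Sigmoid())
opt_g = torch.optim.Adam(generator.parameters(), lr=1e-3)
opt_c = torch.optim.Adam(critic.parameters(), lr=1e-3)
bce = nn.BCELoss()

def set_requires_grad(model, flag):
    "Freeze (flag=False) or unfreeze (flag=True) all parameters of a model."
    for p in model.parameters():
        p.requires_grad_(flag)

def train_step(real, noise_dim=8):
    bs = real.shape[0]
    ones, zeros = torch.ones(bs, 1), torch.zeros(bs, 1)

    # 1) Freeze the generator, train the critic for one step.
    set_requires_grad(generator, False)
    set_requires_grad(critic, True)
    fake = generator(torch.randn(bs, noise_dim))
    # Reward predicting 1 on real images and 0 on fake images.
    crit_loss = bce(critic(real), ones) + bce(critic(fake), zeros)
    opt_c.zero_grad()
    crit_loss.backward()
    opt_c.step()

    # 2) Freeze the critic, train the generator for one step.
    set_requires_grad(generator, True)
    set_requires_grad(critic, False)
    fake = generator(torch.randn(bs, noise_dim))
    # Reward the generator when the critic predicts 1 ("real") on its fakes;
    # gradients flow through the frozen critic into the generator.
    gen_loss = bce(critic(fake), ones)
    opt_g.zero_grad()
    gen_loss.backward()
    opt_g.step()
    return crit_loss.item(), gen_loss.item()

real = torch.randn(32, 2) + 3.0  # toy "real data"
for _ in range(5):
    losses = train_step(real)
```

A common alternative to toggling requires_grad is to call fake.detach() during the critic step; both keep the critic's loss from updating the generator.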
GANModule is just a shell to contain the two models. When called, it delegates the input to either the generator or the critic, depending on the value of gen_mode.
The switch method changes the mode. By default (leaving gen_mode as None), it puts the module in the other mode (critic mode if it was in generator mode and vice versa).
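The delegation and switching behavior can be illustrated with a minimal sketch. TwoModelShell is a hypothetical class written for illustration, not fastai's GANModule source; it assumes the module starts in critic mode, which matches the example that follows.

```python
import torch
from torch import nn

class TwoModelShell(nn.Module):
    "Hypothetical sketch: holds a generator and a critic and forwards to one of them."
    def __init__(self, generator, critic, gen_mode=False):
        super().__init__()
        self.generator, self.critic, self.gen_mode = generator, critic, gen_mode

    def forward(self, *args):
        # Delegate to the generator in generator mode, to the critic otherwise.
        return self.generator(*args) if self.gen_mode else self.critic(*args)

    def switch(self, gen_mode=None):
        # None toggles the current mode; True/False sets it explicitly.
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode

shell = TwoModelShell(generator=nn.Linear(100, 10), critic=nn.Linear(10, 1))
print(shell(torch.randn(2, 10)).shape)   # critic mode: torch.Size([2, 1])
shell.switch()                           # now in generator mode
print(shell(torch.randn(2, 100)).shape)  # torch.Size([2, 10])
```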
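import torch
from fastcore.test import test_eq
from fastai.vision.all import *
from fastai.vision.gan import *  # GANModule, basic_critic, basic_generator
critic = basic_critic(64, 3)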
generator = basic_generator(64, 3)
tst = GANModule(critic=critic, generator=generator)
real = torch.randn(2, 3, 64, 64)
real_p = tst(real)
test_eq(real_p.shape, [2,1])
tst.switch() #tst is now in generator mode
noise = torch.randn(2, 100)
fake = tst(noise)
test_eq(fake.shape, real.shape)
tst.switch() #tst is back in critic mode
fake_p = tst(fake)
test_eq(fake_p.shape, [2,1])
If the generator method is called, this loss function expects the output of the generator and some target (a batch of real images). It evaluates whether the generator successfully fooled the critic using gen_loss_func. This loss function has the following signature
def gen_loss_func(fake_pred, output, target):
so that it can combine the critic's prediction on output (the first argument, fake_pred) with output and target (if you want to mix the GAN loss with other losses, for instance).
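As an illustration of mixing the GAN loss with another loss, here is a sketch of a function with the gen_loss_func signature. The name gen_loss_with_l1 and the l1_weight parameter are hypothetical; it assumes the critic outputs probabilities and adds a weighted L1 pixel loss between the generated images and the targets.

```python
import torch
import torch.nn.functional as F

def gen_loss_with_l1(fake_pred, output, target, l1_weight=100.0):
    "Hypothetical gen_loss_func: adversarial BCE term plus a weighted L1 pixel term."
    # Adversarial part: the generator wants the critic to predict 1 ("real") on its fakes.
    adv = F.binary_cross_entropy(fake_pred, torch.ones_like(fake_pred))
    # Reconstruction part: keep the generated images close to the target images.
    pix = F.l1_loss(output, target)
    return adv + l1_weight * pix

pred = torch.full((2, 1), 0.5)  # critic is undecided on the fakes
img = torch.zeros(2, 3, 4, 4)
same = gen_loss_with_l1(pred, img, img)                   # only the adversarial term: ln 2
diff = gen_loss_with_l1(pred, img, torch.ones_like(img))  # ln 2 + 100 * 1
```

The weighting between the two terms is a design choice: a large l1_weight keeps outputs close to the targets, while a small one lets the adversarial term dominate.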
If the critic method is called, this loss function expects the real_pred given by the critic and some input (the noise fed to the generator). It evaluates the critic using crit_loss_func. This loss function has the following signature
def crit_loss_func(real_pred, fake_pred):
where real_pred is the output of the critic on a batch of real images and fake_pred is the output of the critic on a batch of fake images generated from the noise by the generator.
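For a critic that outputs probabilities, one common choice of crit_loss_func is binary cross-entropy with target 1 for real images and 0 for fake ones. The sketch below uses the hypothetical name bce_crit_loss; it is one possible loss, not the one every GAN variant uses.

```python
import torch
import torch.nn.functional as F

def bce_crit_loss(real_pred, fake_pred):
    "Hypothetical crit_loss_func for a critic that outputs probabilities."
    # Push predictions on real images toward 1 and on fake images toward 0.
    real_loss = F.binary_cross_entropy(real_pred, torch.ones_like(real_pred))
    fake_loss = F.binary_cross_entropy(fake_pred, torch.zeros_like(fake_pred))
    return real_loss + fake_loss

# A critic that is right gets a low loss; one that is fooled gets a high loss.
confident = bce_crit_loss(torch.full((4, 1), 0.9), torch.full((4, 1), 0.1))  # 2 * -ln(0.9)
fooled = bce_crit_loss(torch.full((4, 1), 0.1), torch.full((4, 1), 0.9))     # 2 * -ln(0.1)
```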
We use the generate_noise function to generate noise vectors to pass into the generator for image generation.
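In the DataBlock below, get_x is called on each item (an image path), so a noise function takes that item as an argument but does not need to use it. The sketch below is a hypothetical stand-in (make_noise) written for illustration, assuming 100-dimensional standard normal noise to match the noise size used in the GANModule example above; it is not fastai's generate_noise source.

```python
import torch

def make_noise(item, size=100):
    "Hypothetical get_x noise function: ignores the item and returns fresh noise."
    # The item (an image path) is not used to build the input; the same item
    # still supplies the real image as the target through the ImageBlock.
    return torch.randn(size)

x1 = make_noise("bedroom_0001.jpg")
x2 = make_noise("bedroom_0001.jpg")
print(x1.shape)  # torch.Size([100])
```

Because the noise is redrawn on every call, the same image file is paired with a different input vector each time it is loaded.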
bs = 128
size = 64
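dblock = DataBlock(blocks = (TransformBlock, ImageBlock),  # input: noise tensor, target: real image
                   get_x = generate_noise,                 # the input is random noise, not the file
                   get_items = get_image_files,
                   splitter = IndexSplitter([]),           # empty validation set: every image is used for training
                   item_tfms=Resize(size, method=ResizeMethod.Crop),
                   batch_tfms = Normalize.from_stats(torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5])))  # maps pixels from [0, 1] to [-1, 1]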
path = untar_data(URLs.LSUN_BEDROOMS)
dls = dblock.dataloaders(path, path=path, bs=bs)
dls.show_batch(max_n=16)
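The Normalize transform above uses a mean and standard deviation of 0.5 for every channel. The arithmetic below shows why: (x - 0.5) / 0.5 sends pixel values in [0, 1] to [-1, 1], which matches the output range of a tanh layer, a common final activation for image generators. This is a plain-PyTorch sketch of the computation, not the fastai Normalize implementation.

```python
import torch

# Per-channel statistics shaped for broadcasting over (batch, channel, height, width).
mean = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)
std = torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1)

# One "image" whose three pixels are black (0.0), mid-gray (0.5) and white (1.0) in every channel.
x = torch.tensor([0.0, 0.5, 1.0]).view(1, 1, 1, 3).expand(1, 3, 1, 3)
normed = (x - mean) / std  # values -1, 0, 1 in each channel
```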
from fastai.callback.all import *
generator = basic_generator(64, n_channels=3, n_extra_layers=1)
critic = basic_critic(64, n_channels=3, n_extra_layers=1, act_cls=partial(nn.LeakyReLU, negative_slope=0.2))
learn = GANLearner.wgan(dls, generator, critic, opt_func = RMSProp)
learn.recorder.train_metrics=True
learn.recorder.valid_metrics=False
learn.fit(1, 2e-4, wd=0.)
learn.show_results(max_n=9, ds_idx=0)
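GANLearner.wgan trains a Wasserstein GAN, where the critic outputs an unbounded score rather than a probability. The sketch below shows the textbook WGAN critic update from the original WGAN paper (score gap loss, RMSProp, weight clipping) in plain PyTorch; it is not fastai's implementation, and the toy critic, the fixed toy fake batch, the learning rate and the clip value of 0.01 are assumptions chosen for illustration. The WGAN authors reported that momentum-based optimizers such as Adam were less stable for this setup, which is one reason RMSProp is a common choice.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical toy critic on 2-d data: no sigmoid, the output is an unbounded score.
critic = nn.Sequential(nn.Linear(2, 16), nn.LeakyReLU(0.2), nn.Linear(16, 1))
opt_c = torch.optim.RMSprop(critic.parameters(), lr=5e-5)
clip = 0.01

def wgan_critic_step(real, fake):
    # Minimizing (fake score - real score) widens the gap between the mean scores
    # of real and fake batches, an estimate of the Wasserstein distance.
    loss = critic(fake).mean() - critic(real).mean()
    opt_c.zero_grad()
    loss.backward()
    opt_c.step()
    # Weight clipping keeps the critic roughly Lipschitz, as in the original WGAN.
    with torch.no_grad():
        for p in critic.parameters():
            p.clamp_(-clip, clip)
    return loss.item()

real = torch.randn(64, 2) + 2.0  # toy "real" batch
fake = torch.randn(64, 2)        # toy "fake" batch (would come from a generator)
for _ in range(5):
    loss = wgan_critic_step(real, fake)
```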