GAN

GANs (Generative Adversarial Nets) were invented by Ian Goodfellow. The concept is that we train two models at the same time: a generator and a critic. The generator tries to make new images similar to the ones in a dataset, and the critic tries to distinguish real images from the ones the generator produces. The generator returns images; the critic returns a single number (usually a probability: 0. for fake images and 1. for real ones).

We train them against each other in the sense that at each step (more or less), we:

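  1. Freeze the generator and train the critic for one step by:
     - getting one batch of real images (let's call that real)
     - generating one batch of fake images (let's call that fake)
     - having the critic evaluate each batch and computing a loss function from that; the important part is that it rewards the detection of real images and penalizes the fake ones
     - updating the weights of the critic with the gradients of this loss
  2. Freeze the critic and train the generator for one step by:
     - generating one batch of fake images
     - evaluating the critic on it
     - returning a loss that rewards the critic thinking those are real images
     - updating the weights of the generator with the gradients of this loss
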
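The two alternating steps can be illustrated on a deliberately tiny toy problem. This is a sketch under loud assumptions, not fastai's implementation: every real "image" is the number 3.0, the generator ignores its noise and outputs a single learnable value theta, the critic is a logistic model sigmoid(w * x + b) trained with binary cross-entropy, and "freezing" a model simply means not updating its parameters during the other model's step. The gradients are written out by hand so that only the standard library is needed.

```python
import math

# Toy assumptions (illustration only): real data is always 3.0, the generator
# outputs a learnable scalar theta, the critic is D(x) = sigmoid(w * x + b).
REAL = 3.0
LR = 0.1


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def critic_step(p: dict) -> None:
    """Generator frozen: only the critic parameters w and b are updated."""
    fake = p["theta"]
    p_real = sigmoid(p["w"] * REAL + p["b"])  # critic's probability that the real sample is real
    p_fake = sigmoid(p["w"] * fake + p["b"])  # critic's probability that the fake sample is real
    # Gradients of BCE(p_real, 1) + BCE(p_fake, 0) with respect to w and b
    grad_w = (p_real - 1.0) * REAL + p_fake * fake
    grad_b = (p_real - 1.0) + p_fake
    p["w"] -= LR * grad_w
    p["b"] -= LR * grad_b


def generator_step(p: dict) -> None:
    """Critic frozen: only the generator parameter theta is updated."""
    p_fake = sigmoid(p["w"] * p["theta"] + p["b"])
    # Gradient of BCE(p_fake, 1) = -log(p_fake) with respect to theta
    grad_theta = (p_fake - 1.0) * p["w"]
    p["theta"] -= LR * grad_theta


params = {"theta": 0.0, "w": 0.0, "b": 0.0}
history = []  # snapshot after each phase, to check that frozen parameters stay put
for _ in range(20):
    critic_step(params)
    history.append(("critic", dict(params)))
    generator_step(params)
    history.append(("generator", dict(params)))

print(params)
```

In this toy run theta starts moving upward from 0.0, because the critic learns that larger values look real and the generator follows that signal. Real GAN training dynamics are far less tame than this.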
Note

The fastai library provides support for training GANs through the GANTrainer, but doesn’t include more than basic models.

Wrapping the modules



GANModule


def GANModule(
    generator:nn.Module=None, # The generator PyTorch module
    critic:nn.Module=None, # The discriminator PyTorch module
    gen_mode:None | bool=False, # Whether the GAN should be set to generator mode
):

Wrapper around a generator and a critic to create a GAN.

This is just a shell to contain the two models. When called, it will either delegate the input to the generator or the critic depending on the value of gen_mode.



GANModule.switch


def switch(
    gen_mode:None | bool=None, # Whether the GAN should be set to generator mode
):

Put the module in generator mode if gen_mode is True, in critic mode otherwise.

By default (leaving gen_mode as None), this puts the module in the other mode (critic mode if it was in generator mode and vice versa).
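The delegation and switching behaviour described above can be sketched in plain Python. ToyGANModule and the two lambdas are hypothetical stand-ins (not fastai's code): the "generator" doubles its input and the "critic" returns the mean, just so the routing is visible.

```python
from typing import Callable, Optional


class ToyGANModule:
    """Hypothetical stand-in for GANModule: holds two callables and routes calls by gen_mode."""

    def __init__(self, generator: Callable, critic: Callable, gen_mode: bool = False):
        self.generator, self.critic, self.gen_mode = generator, critic, gen_mode

    def __call__(self, x):
        # Generator mode: x is noise; critic mode: x is a batch of images
        return self.generator(x) if self.gen_mode else self.critic(x)

    def switch(self, gen_mode: Optional[bool] = None) -> None:
        # None flips the current mode; True/False set it explicitly
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode


gan = ToyGANModule(generator=lambda z: [2 * v for v in z],  # toy "generator"
                   critic=lambda x: sum(x) / len(x))         # toy "critic" score
critic_out = gan([1, 3])  # critic mode by default -> 2.0
gan.switch()              # now in generator mode
gen_out = gan([1, 3])     # -> [2, 6]
gan.switch(True)          # an explicit True keeps generator mode
```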



basic_critic


def basic_critic(
    in_size:int, # Input size for the critic (same as the output size of the generator)
    n_channels:int, # Number of channels of the input for the critic
    n_features:int=64, # Number of features used in the critic
    n_extra_layers:int=0, # Number of extra hidden layers in the critic
    norm_type:NormType=NormType.Batch, # Type of normalization to use in the critic
    ks:int=3, stride:int=1, padding:NoneType=None, bias:NoneType=None, ndim:int=2, bn_1st:bool=True,
    act_cls:type=ReLU, transpose:bool=False, init:str='auto', xtra:NoneType=None, bias_std:float=0.01,
    dilation:Union=1, groups:int=1, padding_mode:Literal='zeros', device:NoneType=None, dtype:NoneType=None
)->nn.Sequential:

A basic critic for images n_channels x in_size x in_size.
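To get a feel for the size of the critic, here is a rough sketch of its downsampling schedule. It assumes (this is not stated above) that basic_critic follows the usual DCGAN layout: a first stride-2 convolution to n_features maps, then stride-2 convolutions that halve the spatial size and double the features until a 4x4 map remains, and a final 4x4 convolution that outputs one score per image. critic_schedule is a hypothetical helper.

```python
def critic_schedule(in_size: int, n_features: int = 64) -> list:
    """(spatial size, feature maps) after each stride-2 conv, assuming a DCGAN-style critic."""
    # First conv: n_channels -> n_features, halving the spatial size
    cur_size, cur_ftrs = in_size // 2, n_features
    schedule = [(cur_size, cur_ftrs)]
    # Keep halving the size and doubling the features until a 4x4 map is reached
    while cur_size > 4:
        cur_size //= 2
        cur_ftrs *= 2
        schedule.append((cur_size, cur_ftrs))
    # A final 4x4 conv without padding would then map cur_ftrs -> 1 score per image
    return schedule


print(critic_schedule(64))  # [(32, 64), (16, 128), (8, 256), (4, 512)]
```

Under this assumption, extra layers added with n_extra_layers would be stride-1 and leave the schedule unchanged.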



AddChannels


def AddChannels(
    n_dim
):

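Add n_dim trailing dimensions of size 1 to the input (for example, turning a batch of noise vectors of shape (bs, in_sz) into shape (bs, in_sz, 1, 1)).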
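Assuming the "channels" added here are trailing dimensions of size 1 (a reshape that leaves the data untouched), the shape arithmetic is simple. This is what lets a batch of noise vectors look like a batch of 1x1 images that a transposed convolution can then upsample. add_channels_shape is a hypothetical helper working on shape tuples only.

```python
def add_channels_shape(shape: tuple, n_dim: int) -> tuple:
    """Shape after appending n_dim trailing dimensions of size 1 (hypothetical helper)."""
    return tuple(shape) + (1,) * n_dim


noise_shape = (2, 100)                     # a batch of 2 noise vectors of size 100
print(add_channels_shape(noise_shape, 2))  # (2, 100, 1, 1)
```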



basic_generator


def basic_generator(
    out_size:int, # Output size for the generator (same as the input size for the critic)
    n_channels:int, # Number of channels of the output of the generator
    in_sz:int=100, # Size of the input noise vector for the generator
    n_features:int=64, # Number of features used in the generator
    n_extra_layers:int=0, # Number of extra hidden layers in the generator
    ks:int=3, stride:int=1, padding:NoneType=None, bias:NoneType=None, ndim:int=2,
    norm_type:NormType=NormType.Batch, bn_1st:bool=True, act_cls:type=ReLU, transpose:bool=False,
    init:str='auto', xtra:NoneType=None, bias_std:float=0.01, dilation:Union=1, groups:int=1,
    padding_mode:Literal='zeros', device:NoneType=None, dtype:NoneType=None
)->nn.Sequential:

A basic generator from in_sz to images n_channels x out_size x out_size.

import torch
from fastcore.test import test_eq
from fastai.vision.gan import *

critic = basic_critic(64, 3)
generator = basic_generator(64, 3)
tst = GANModule(critic=critic, generator=generator)
real = torch.randn(2, 3, 64, 64)
real_p = tst(real)
test_eq(real_p.shape, [2,1])

tst.switch() #tst is now in generator mode
noise = torch.randn(2, 100)
fake = tst(noise)
test_eq(fake.shape, real.shape)

tst.switch() #tst is back in critic mode
fake_p = tst(fake)
test_eq(fake_p.shape, [2,1])


DenseResBlock


def DenseResBlock(
    nf:int, # Number of features
    norm_type:NormType=NormType.Batch, # Normalization type
    ks:int=3, stride:int=1, padding:NoneType=None, bias:NoneType=None, ndim:int=2, bn_1st:bool=True,
    act_cls:type=ReLU, transpose:bool=False, init:str='auto', xtra:NoneType=None, bias_std:float=0.01,
    dilation:Union=1, groups:int=1, padding_mode:Literal='zeros', device:NoneType=None, dtype:NoneType=None
)->SequentialEx:

ResNet block of nf features. The remaining keyword arguments (ks, stride, padding, ...) are passed to ConvLayer.



gan_critic


def gan_critic(
    n_channels:int=3, # Number of channels of the input for the critic
    nf:int=128, # Number of features for the critic
    n_blocks:int=3, # Number of ResNet blocks within the critic
    p:float=0.15, # Amount of dropout in the critic
)->nn.Sequential:

Critic to train a GAN.



GANLoss


def GANLoss(
    gen_loss_func:Callable, # Generator loss function
    crit_loss_func:Callable, # Critic loss function
    gan_model:GANModule, # The GAN model
):

Wrapper around crit_loss_func and gen_loss_func.



GANLoss.generator


def generator(
    output, # Generator outputs
    target, # Real images
):

Evaluate output with the critic, then use self.gen_loss_func to evaluate how well the critic was fooled by output.



GANLoss.critic


def critic(
    real_pred, # Critic predictions for real images
    input, # Input noise vector to pass into generator
):

Create some fake_pred with the generator from input and compare them to real_pred in self.crit_loss_func.

If the generator method is called, this loss function expects the output of the generator and some target (a batch of real images). It evaluates whether the generator successfully fooled the critic using gen_loss_func, which has the following signature:

def gen_loss_func(fake_pred, output, target):

so that the output of the critic on output (the first argument, fake_pred) can be combined with output and target (for instance, to mix the GAN loss with other losses).

If the critic method is called, this loss function expects the real_pred given by the critic and some input (the noise fed to the generator). It evaluates the critic using crit_loss_func, which has the following signature:

def crit_loss_func(real_pred, fake_pred):

where real_pred is the output of the critic on a batch of real images and fake_pred is generated from the noise using the generator.
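The dispatch between the two loss functions can be sketched in plain Python. ToyGANLoss, mean and the toy loss lambdas are hypothetical: the "generator" doubles its noise, the "critic" averages its input, and the losses use a simple "critic score" convention chosen only to make the arithmetic easy to follow.

```python
from typing import Callable, List


def mean(xs: List[float]) -> float:
    return sum(xs) / len(xs)


class ToyGANLoss:
    """Hypothetical sketch of the GANLoss dispatch described above (not fastai's code)."""

    def __init__(self, gen_loss_func: Callable, crit_loss_func: Callable,
                 generator: Callable, critic: Callable):
        self.gen_loss_func, self.crit_loss_func = gen_loss_func, crit_loss_func
        self.gen_model, self.crit_model = generator, critic

    def generator(self, output, target):
        # Score the generated images with the critic, then judge how well it was fooled
        fake_pred = self.crit_model(output)
        return self.gen_loss_func(fake_pred, output, target)

    def critic(self, real_pred, noise):
        # Generate fakes from the noise (a real implementation would keep this step
        # out of the generator's gradient computation) and score them with the critic
        fake = self.gen_model(noise)
        fake_pred = self.crit_model(fake)
        return self.crit_loss_func(real_pred, fake_pred)


loss = ToyGANLoss(
    gen_loss_func=lambda fake_pred, output, target: -fake_pred,         # toy: reward high critic scores
    crit_loss_func=lambda real_pred, fake_pred: fake_pred - real_pred,  # toy: real high, fake low
    generator=lambda z: [2 * v for v in z],
    critic=mean,
)
g = loss.generator([1.0, 3.0], target=None)       # critic scores the fake batch 2.0 -> loss -2.0
c = loss.critic(real_pred=5.0, noise=[1.0, 2.0])  # fakes [2, 4] score 3.0 -> 3.0 - 5.0 = -2.0
```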



AdaptiveLoss


def AdaptiveLoss(
    crit:Callable
):

Expand the target to match the output size before applying crit.
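A plain-Python sketch of the idea, assuming the critic returns several scores per image while the target holds one label per image. In PyTorch the expansion would typically be written as target[:, None].expand_as(output); here adaptive_loss and the toy mse criterion are hypothetical helpers on nested lists.

```python
from typing import Callable, List


def mse(output: List[List[float]], target: List[List[float]]) -> float:
    """Mean squared error over all entries (toy criterion)."""
    pairs = [(o, t) for row_o, row_t in zip(output, target) for o, t in zip(row_o, row_t)]
    return sum((o - t) ** 2 for o, t in pairs) / len(pairs)


def adaptive_loss(crit: Callable, output: List[List[float]], target: List[float]) -> float:
    # Repeat each sample's single target across all of that sample's outputs
    expanded = [[float(t)] * len(row) for row, t in zip(output, target)]
    return crit(output, expanded)


# Two images, each scored with 2 numbers by the critic
output = [[1.0, 0.5], [0.0, 0.5]]
target = [1, 0]  # one label per image
print(adaptive_loss(mse, output, target))  # (0 + 0.25 + 0 + 0.25) / 4 = 0.125
```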



accuracy_thresh_expand


def accuracy_thresh_expand(
    y_pred:Tensor, y_true:Tensor, thresh:float=0.5, sigmoid:bool=True
):

Compute thresholded accuracy after expanding y_true to the size of y_pred.
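The same expansion applies to this metric. Below is a plain-Python sketch (accuracy_thresh_expand_sketch is a hypothetical name, not the library function): predictions are optionally passed through a sigmoid, thresholded, and compared with each image's label repeated across all of that image's outputs.

```python
import math
from typing import List


def accuracy_thresh_expand_sketch(y_pred: List[List[float]], y_true: List[int],
                                  thresh: float = 0.5, sigmoid: bool = True) -> float:
    """Threshold predictions and compare them with the expanded targets."""
    hits, total = 0, 0
    for row, label in zip(y_pred, y_true):
        for p in row:
            if sigmoid:
                p = 1.0 / (1.0 + math.exp(-p))       # logits -> probabilities
            hits += int((p > thresh) == bool(label))  # label repeated for every entry of the row
            total += 1
    return hits / total


logits = [[2.0, -1.0], [0.5, 3.0]]
print(accuracy_thresh_expand_sketch(logits, [1, 0]))  # 0.25
```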

Callbacks for GAN training



set_freeze_model


def set_freeze_model(
    m:nn.Module, # Model to freeze/unfreeze
    rg:bool, # `Requires grad` argument. `True` to unfreeze (make trainable), `False` to freeze
):
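A sketch of what set_freeze_model amounts to, assuming it sets requires_grad to rg on every parameter of m (in PyTorch, via m.parameters() and requires_grad_). Param, ToyModel and set_freeze_model_sketch are hypothetical plain-Python stand-ins so the example runs without PyTorch; the usage mirrors the GAN step of training one model while the other is frozen.

```python
class Param:
    """Hypothetical stand-in for a tensor parameter with a requires_grad flag."""

    def __init__(self):
        self.requires_grad = True


class ToyModel:
    """Hypothetical stand-in for an nn.Module exposing parameters()."""

    def __init__(self, n: int):
        self._params = [Param() for _ in range(n)]

    def parameters(self):
        return iter(self._params)


def set_freeze_model_sketch(m, rg: bool) -> None:
    # rg=True -> parameters receive gradients (trainable); rg=False -> frozen
    for p in m.parameters():
        p.requires_grad = rg


generator, critic = ToyModel(3), ToyModel(2)
gen_mode = True                                # training the generator on this step
set_freeze_model_sketch(generator, gen_mode)   # generator trainable
set_freeze_model_sketch(critic, not gen_mode)  # critic frozen
```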


GANTrainer


def GANTrainer(
    switch_eval:bool=False, # Whether the model should be set to eval mode when calculating loss
    clip:None | float=None, # How much to clip the weights
    beta:float=0.98, # Exponentially weighted smoothing of the losses `beta`
    gen_first:bool=False, # Whether we start with generator training
    show_img:bool=True, # Whether to show example generated images during training
):

Callback to handle GAN Training.

Warning

The GANTrainer is useless on its own; you need to complete it with one of the following switchers.
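The beta parameter of GANTrainer controls exponentially weighted smoothing of the generator and critic losses. The exact smoothing GANTrainer applies is not shown here, so the following is a sketch assuming the usual bias-corrected exponential moving average; smooth_losses is a hypothetical helper.

```python
from typing import Iterable, List


def smooth_losses(losses: Iterable[float], beta: float = 0.98) -> List[float]:
    """Bias-corrected exponentially weighted moving average of a loss stream."""
    avg, out = 0.0, []
    for i, loss in enumerate(losses, start=1):
        avg = beta * avg + (1 - beta) * loss  # exponential moving average
        out.append(avg / (1 - beta ** i))     # debias the early, zero-initialised values
    return out


print(smooth_losses([1.0, 1.0, 1.0]))       # a constant loss stays (approximately) 1.0
print(smooth_losses([1.0, 0.0], beta=0.5))  # [1.0, 0.333...]
```

The debiasing term matters early in training: without it the first smoothed values would be pulled toward the zero the average starts from.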



FixedGANSwitcher


def FixedGANSwitcher(
    n_crit:int=1, # How many steps of critic training before switching to generator
    n_gen:int=1, # How many steps of generator training before switching to critic
):

Switcher to do n_crit iterations of the critic then n_gen iterations of the generator.
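A simplified model of the resulting schedule, assuming training starts with the critic unless gen_first is set (as in the gen_first parameters of GANTrainer and GANLearner). fixed_schedule is a hypothetical generator that yields which model is trained on each batch.

```python
from itertools import islice
from typing import Iterator


def fixed_schedule(n_crit: int = 1, n_gen: int = 1, gen_first: bool = False) -> Iterator[str]:
    """Yield which model is trained on each batch: n_crit critic steps, then n_gen generator steps."""
    gen_mode = gen_first
    while True:
        for _ in range(n_gen if gen_mode else n_crit):
            yield "generator" if gen_mode else "critic"
        gen_mode = not gen_mode  # switch after completing the block of steps


print(list(islice(fixed_schedule(n_crit=5, n_gen=1), 7)))
# ['critic', 'critic', 'critic', 'critic', 'critic', 'generator', 'critic']
```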



AdaptiveGANSwitcher


def AdaptiveGANSwitcher(
    gen_thresh:None | float=None, # Loss threshold for generator
    critic_thresh:None | float=None, # Loss threshold for critic
):

Switcher that keeps training the generator (or the critic) until its loss goes below gen_thresh (or critic_thresh), then switches to the other model.
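One reading of this rule, sketched on a fixed list of per-batch losses. adaptive_modes is hypothetical, and the behaviour for a threshold of None (switching after every batch) is an assumption rather than something stated above.

```python
from typing import List, Optional


def adaptive_modes(losses: List[float], gen_thresh: Optional[float] = None,
                   critic_thresh: Optional[float] = None, gen_first: bool = False) -> List[str]:
    """Which model is trained on each batch, given the loss observed on that batch."""
    gen_mode, modes = gen_first, []
    for loss in losses:
        modes.append("generator" if gen_mode else "critic")
        thresh = gen_thresh if gen_mode else critic_thresh
        # Switch once the current model's loss drops below its threshold (None: assumed to switch every batch)
        if thresh is None or loss < thresh:
            gen_mode = not gen_mode
    return modes


print(adaptive_modes([0.9, 0.4, 0.8, 0.7, 0.2], gen_thresh=0.5, critic_thresh=0.5))
# ['critic', 'critic', 'generator', 'generator', 'generator']
```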



GANDiscriminativeLR


def GANDiscriminativeLR(
    mult_lr:float=5.0
):

Callback that handles multiplying the learning rate by mult_lr for the critic.

GAN data



InvisibleTensor


def InvisibleTensor(
    *args, **kwargs
):

A TensorBase whose show method does nothing.



generate_noise


def generate_noise(
    fn, # Dummy argument so it works with `DataBlock`
    size:int=100, # Size of returned noise vector
)->InvisibleTensor:

Generate noise vector.

We use the generate_noise function to generate noise vectors to pass into the generator for image generation.
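A plain-Python sketch of why fn is a dummy argument: a DataBlock calls get_x once per item (here, per image file), and the noise does not depend on the file at all. generate_noise_sketch and the file names are hypothetical, and standard-normal noise is an assumption (what torch.randn would produce).

```python
import random
from typing import List


def generate_noise_sketch(fn, size: int = 100) -> List[float]:
    """fn is ignored: it only exists so a DataBlock can call get_x on each item."""
    return [random.gauss(0.0, 1.0) for _ in range(size)]


items = ["img_0.jpg", "img_1.jpg"]              # hypothetical file names
xs = [generate_noise_sketch(f) for f in items]  # one noise vector per image
```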

from fastai.vision.all import *
from fastai.vision.gan import *

bs = 128
size = 64
dblock = DataBlock(blocks = (TransformBlock, ImageBlock),
                   get_x = generate_noise,
                   get_items = get_image_files,
                   splitter = IndexSplitter([]),
                   item_tfms=Resize(size, method=ResizeMethod.Crop), 
                   batch_tfms = Normalize.from_stats(torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5])))
path = untar_data(URLs.LSUN_BEDROOMS)
dls = dblock.dataloaders(path, path=path, bs=bs)
dls.show_batch(max_n=16)

GAN Learner



gan_loss_from_func


def gan_loss_from_func(
    loss_gen:Callable, # A loss function for the generator. Evaluates generator output images and target real images
    loss_crit:Callable, # A loss function for the critic. Evaluates predictions of real and fake images.
    weights_gen:None | MutableSequence | tuple=None, # Weights for the adversarial and `loss_gen` terms of the generator loss
):

Define loss functions for a GAN from loss_gen and loss_crit.
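A plain-Python sketch of how the two loss functions might be combined. It assumes that loss_crit compares predictions with 0/1 targets (1 meaning "real"), that the generator loss is weights_gen[0] times the adversarial term plus weights_gen[1] times loss_gen(output, target), defaulting to (1, 1), and that the critic loss averages its real and fake terms. gan_loss_from_func_sketch is hypothetical, and mse stands in for both losses only to keep the arithmetic exact; in practice loss_crit would usually be a binary cross-entropy.

```python
from typing import Callable, Optional, Sequence, Tuple


def gan_loss_from_func_sketch(loss_gen: Callable, loss_crit: Callable,
                              weights_gen: Optional[Sequence[float]] = None) -> Tuple[Callable, Callable]:
    w_adv, w_gen = weights_gen if weights_gen is not None else (1.0, 1.0)

    def loss_G(fake_pred, output, target):
        # Adversarial term (the critic should call fakes "real" = 1) plus a direct output/target term
        ones = [1.0] * len(fake_pred)
        return w_adv * loss_crit(fake_pred, ones) + w_gen * loss_gen(output, target)

    def loss_C(real_pred, fake_pred):
        # The critic should call real images 1 and fakes 0; average the two terms
        ones, zeros = [1.0] * len(real_pred), [0.0] * len(fake_pred)
        return (loss_crit(real_pred, ones) + loss_crit(fake_pred, zeros)) / 2

    return loss_G, loss_C


def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


loss_G, loss_C = gan_loss_from_func_sketch(loss_gen=mse, loss_crit=mse, weights_gen=(1.0, 10.0))
g = loss_G([0.5, 1.0], output=[0.0, 1.0], target=[1.0, 1.0])  # 0.125 + 10 * 0.5 = 5.125
c = loss_C([1.0, 0.5], [0.0, 0.5])                             # (0.125 + 0.125) / 2 = 0.125
```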



GANLearner


def GANLearner(
    dls:DataLoaders, # DataLoaders object for GAN data
    generator:nn.Module, # Generator model
    critic:nn.Module, # Critic model
    gen_loss_func:Callable, # Generator loss function
    crit_loss_func:Callable, # Critic loss function
    switcher:Callback | None=None, # Callback for switching between generator and critic training, defaults to `FixedGANSwitcher`
    gen_first:bool=False, # Whether we start with generator training
    switch_eval:bool=True, # Whether the model should be set to eval mode when calculating loss
    show_img:bool=True, # Whether to show example generated images during training
    clip:None | float=None, # How much to clip the weights
    cbs:Callback | None | MutableSequence=None, # Additional callbacks
    metrics:None | MutableSequence | Callable=None, # Metrics
    loss_func:Callable | None=None, # Loss function. Defaults to `dls` loss
    opt_func:Optimizer | OptimWrapper=Adam, # Optimization function for training
    lr:float | slice=0.001, # Default learning rate
    splitter:Callable=trainable_params, # Split model into parameter groups. Defaults to one parameter group
    path:str | Path | None=None, # Parent directory to save, load, and export models. Defaults to `dls` `path`
    model_dir:str | Path='models', # Subdirectory to save and load models
    wd:float | int | None=None, # Default weight decay
    wd_bn_bias:bool=False, # Apply weight decay to normalization and bias parameters
    train_bn:bool=True, # Train frozen normalization layers
    moms:tuple=(0.95, 0.85, 0.95), # Default momentum for schedulers
    default_cbs:bool=True, # Include default `Callback`s
):

A Learner suitable for GANs.



GANLearner.from_learners


def from_learners(
    gen_learn:Learner, # A `Learner` object that contains the generator
    crit_learn:Learner, # A `Learner` object that contains the critic
    switcher:Callback | None=None, # Callback for switching between generator and critic training, defaults to `FixedGANSwitcher`
    weights_gen:None | MutableSequence | tuple=None, # Weights for the adversarial and `loss_gen` terms of the generator loss
    gen_first:bool=False, # Whether we start with generator training
    switch_eval:bool=True, # Whether the model should be set to eval mode when calculating loss
    show_img:bool=True, # Whether to show example generated images during training
    clip:None | float=None, # How much to clip the weights
    cbs:Callback | None | MutableSequence=None, # Additional callbacks
    metrics:None | MutableSequence | Callable=None, # Metrics
    loss_func:Callable | None=None, # Loss function. Defaults to `dls` loss
    opt_func:Optimizer | OptimWrapper=Adam, # Optimization function for training
    lr:float | slice=0.001, # Default learning rate
    splitter:Callable=trainable_params, # Split model into parameter groups. Defaults to one parameter group
    path:str | Path | None=None, # Parent directory to save, load, and export models. Defaults to `dls` `path`
    model_dir:str | Path='models', # Subdirectory to save and load models
    wd:float | int | None=None, # Default weight decay
    wd_bn_bias:bool=False, # Apply weight decay to normalization and bias parameters
    train_bn:bool=True, # Train frozen normalization layers
    moms:tuple=(0.95, 0.85, 0.95), # Default momentum for schedulers
    default_cbs:bool=True, # Include default `Callback`s
):

Create a GAN from gen_learn and crit_learn.



GANLearner.wgan


def wgan(
    dls:DataLoaders, # DataLoaders object for GAN data
    generator:nn.Module, # Generator model
    critic:nn.Module, # Critic model
    switcher:Callback | None=None, # Callback for switching between generator and critic training, defaults to `FixedGANSwitcher(n_crit=5, n_gen=1)`
    clip:None | float=0.01, # How much to clip the weights
    switch_eval:bool=False, # Whether the model should be set to eval mode when calculating loss
    gen_first:bool=False, # Whether we start with generator training
    show_img:bool=True, # Whether to show example generated images during training
    cbs:Callback | None | MutableSequence=None, # Additional callbacks
    metrics:None | MutableSequence | Callable=None, # Metrics
    loss_func:Callable | None=None, # Loss function. Defaults to `dls` loss
    opt_func:Optimizer | OptimWrapper=Adam, # Optimization function for training
    lr:float | slice=0.001, # Default learning rate
    splitter:Callable=trainable_params, # Split model into parameter groups. Defaults to one parameter group
    path:str | Path | None=None, # Parent directory to save, load, and export models. Defaults to `dls` `path`
    model_dir:str | Path='models', # Subdirectory to save and load models
    wd:float | int | None=None, # Default weight decay
    wd_bn_bias:bool=False, # Apply weight decay to normalization and bias parameters
    train_bn:bool=True, # Train frozen normalization layers
    moms:tuple=(0.95, 0.85, 0.95), # Default momentum for schedulers
    default_cbs:bool=True, # Include default `Callback`s
):

Create a WGAN from dls, generator and critic.
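In a WGAN the critic outputs an unbounded score rather than a probability, its weights are clipped after each update (clip=0.01 by default above), and it is trained for several steps per generator step (FixedGANSwitcher(n_crit=5, n_gen=1) by default). The sketch below shows the Wasserstein-style losses and the clipping on plain lists. The helpers are hypothetical, and the sign convention is one common choice; fastai's internal convention may differ.

```python
from typing import List


def mean(xs: List[float]) -> float:
    return sum(xs) / len(xs)


def wgan_critic_loss(real_pred: List[float], fake_pred: List[float]) -> float:
    # The critic maximises mean(real) - mean(fake), i.e. minimises its negative
    return mean(fake_pred) - mean(real_pred)


def wgan_generator_loss(fake_pred: List[float]) -> float:
    # The generator wants the critic to score its fakes highly
    return -mean(fake_pred)


def clip_weights(weights: List[float], clip: float = 0.01) -> List[float]:
    # Weight clipping applied to the critic after each update
    return [max(-clip, min(clip, w)) for w in weights]


print(wgan_critic_loss([1.0, 3.0], [0.0, 2.0]))  # 1.0 - 2.0 = -1.0
print(wgan_generator_loss([0.0, 2.0]))           # -1.0
print(clip_weights([0.5, -0.002, -1.0]))         # [0.01, -0.002, -0.01]
```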

from fastai.callback.all import *
generator = basic_generator(64, n_channels=3, n_extra_layers=1)
critic    = basic_critic   (64, n_channels=3, n_extra_layers=1, act_cls=partial(nn.LeakyReLU, negative_slope=0.2))
learn = GANLearner.wgan(dls, generator, critic, opt_func = RMSProp)
learn.recorder.train_metrics=True
learn.recorder.valid_metrics=False
learn.fit(1, 2e-4, wd=0.)
epoch  train_loss  gen_loss  crit_loss  time
0      -0.815071   0.646809  -1.140522  00:38
learn.show_results(max_n=9, ds_idx=0)