GAN

GAN stands for Generative Adversarial Nets, which were invented by Ian Goodfellow. The concept is that we train two models at the same time: a generator and a critic. The generator tries to make new images similar to the ones in a dataset, and the critic tries to distinguish real images from the ones the generator produces. The generator returns images; the critic returns a single number (usually a probability, 0. for fake images and 1. for real ones).

We train them against each other in the sense that at each step (more or less), we:

  1. Freeze the generator and train the critic for one step.
  2. Freeze the critic and train the generator for one step.
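
The alternation above can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not the fastai implementation: the tiny linear models, the `set_requires_grad` helper, the binary cross-entropy losses and the optimizer settings are all hypothetical choices made for the example.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy stand-ins (hypothetical, not fastai models): flat "images" of size 8
generator = nn.Linear(4, 8)
critic = nn.Linear(8, 1)  # returns raw logits
opt_g = torch.optim.Adam(generator.parameters(), lr=1e-3)
opt_c = torch.optim.Adam(critic.parameters(), lr=1e-3)
bce = nn.BCEWithLogitsLoss()

def set_requires_grad(model, flag):
    # flag=False freezes the model, flag=True makes it trainable again
    for p in model.parameters():
        p.requires_grad_(flag)

def critic_step(real):
    # Freeze the generator, train the critic on one real and one fake batch
    set_requires_grad(generator, False)
    set_requires_grad(critic, True)
    n = real.size(0)
    fake = generator(torch.randn(n, 4)).detach()
    loss = bce(critic(real), torch.ones(n, 1)) + bce(critic(fake), torch.zeros(n, 1))
    opt_c.zero_grad()
    loss.backward()
    opt_c.step()
    return loss.item()

def generator_step(n):
    # Freeze the critic, reward the generator when fakes are scored as real
    set_requires_grad(critic, False)
    set_requires_grad(generator, True)
    fake = generator(torch.randn(n, 4))
    loss = bce(critic(fake), torch.ones(n, 1))
    opt_g.zero_grad()
    loss.backward()
    opt_g.step()
    return loss.item()

real = torch.randn(16, 8) + 2.0  # placeholder for a batch of real data
for step in range(3):
    c_loss = critic_step(real)
    g_loss = generator_step(16)
```

Each optimizer only holds the parameters of its own model, so a generator step never updates the critic and vice versa.
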
Note

The fastai library provides support for training GANs through the GANTrainer, but doesn’t include more than basic models.

Wrapping the modules


GANModule

 GANModule (generator:nn.Module=None, critic:nn.Module=None,
            gen_mode:None|bool=False)

Wrapper around a generator and a critic to create a GAN.

Type Default Details
generator Module None The generator PyTorch module
critic Module None The discriminator PyTorch module
gen_mode None | bool False Whether the GAN should be set to generator mode

This is just a shell to contain the two models. When called, it delegates the input to either the generator or the critic, depending on the value of gen_mode.


GANModule.switch

 GANModule.switch (gen_mode:None|bool=None)

Put the module in generator mode if gen_mode is True, in critic mode otherwise.

Type Default Details
gen_mode None | bool None Whether the GAN should be set to generator mode

By default (leaving gen_mode as None), this puts the module in the other mode (critic mode if it was in generator mode and vice versa).
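
The delegation and toggling behaviour described above can be illustrated with a small stand-in class. `TinyGANModule` is a hypothetical name used only for this sketch; it mirrors the documented behaviour but is not the fastai class.

```python
import torch
from torch import nn

class TinyGANModule(nn.Module):
    "Hypothetical stand-in for illustration; not the fastai class."
    def __init__(self, generator, critic, gen_mode=False):
        super().__init__()
        self.generator, self.critic, self.gen_mode = generator, critic, gen_mode

    def forward(self, *args):
        # Route the input to whichever model is currently active
        return self.generator(*args) if self.gen_mode else self.critic(*args)

    def switch(self, gen_mode=None):
        # None toggles the current mode; True/False sets it explicitly
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode

m = TinyGANModule(generator=nn.Linear(4, 8), critic=nn.Linear(8, 1))
```
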


basic_critic

 basic_critic (in_size:int, n_channels:int, n_features:int=64,
               n_extra_layers:int=0, norm_type:NormType=<NormType.Batch:
               1>, ks=3, stride=1, padding=None, bias=None, ndim=2,
               bn_1st=True, act_cls=<class
               'torch.nn.modules.activation.ReLU'>, transpose=False,
               init='auto', xtra=None, bias_std=0.01,
               dilation:Union[int,Tuple[int,int]]=1, groups:int=1,
               padding_mode:str='zeros', device=None, dtype=None)

A basic critic for images n_channels x in_size x in_size.

Type Default Details
in_size int Input size for the critic (same as the output size of the generator)
n_channels int Number of channels of the input for the critic
n_features int 64 Number of features used in the critic
n_extra_layers int 0 Number of extra hidden layers in the critic
norm_type NormType NormType.Batch Type of normalization to use in the critic
ks int 3
stride int 1
padding NoneType None
bias NoneType None
ndim int 2
bn_1st bool True
act_cls type ReLU
transpose bool False
init str auto
xtra NoneType None
bias_std float 0.01
dilation Union 1
groups int 1
padding_mode str zeros
device NoneType None
dtype NoneType None
Returns nn.Sequential

AddChannels

 AddChannels (n_dim)

Add n_dim channels at the end of the input.
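
One plausible reading of this description, stated here as an assumption rather than as the library's implementation, is that trailing singleton dimensions are appended so that a flat noise vector can be fed to convolutional layers. A sketch in plain PyTorch:

```python
import torch

noise = torch.randn(2, 100)  # batch of 2 noise vectors
n_dim = 2                    # number of trailing dimensions to append

# Append n_dim axes of size 1: (2, 100) becomes (2, 100, 1, 1)
reshaped = noise.view(*noise.shape, *([1] * n_dim))
```
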


basic_generator

 basic_generator (out_size:int, n_channels:int, in_sz:int=100,
                  n_features:int=64, n_extra_layers:int=0, ks=3, stride=1,
                  padding=None, bias=None, ndim=2,
                  norm_type=<NormType.Batch: 1>, bn_1st=True,
                  act_cls=<class 'torch.nn.modules.activation.ReLU'>,
                  transpose=False, init='auto', xtra=None, bias_std=0.01,
                  dilation:Union[int,Tuple[int,int]]=1, groups:int=1,
                  padding_mode:str='zeros', device=None, dtype=None)

A basic generator from in_sz to images n_channels x out_size x out_size.

Type Default Details
out_size int Output size for the generator (same as the input size for the critic)
n_channels int Number of channels of the output of the generator
in_sz int 100 Size of the input noise vector for the generator
n_features int 64 Number of features used in the generator
n_extra_layers int 0 Number of extra hidden layers in the generator
ks int 3
stride int 1
padding NoneType None
bias NoneType None
ndim int 2
norm_type NormType NormType.Batch
bn_1st bool True
act_cls type ReLU
transpose bool False
init str auto
xtra NoneType None
bias_std float 0.01
dilation Union 1
groups int 1
padding_mode str zeros
device NoneType None
dtype NoneType None
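Returns nn.Sequential

import torch
from fastcore.test import test_eq
from fastai.vision.gan import basic_critic, basic_generator, GANModule

# Build a critic and a generator for 3-channel 64x64 images
critic = basic_critic(64, 3)
generator = basic_generator(64, 3)
tst = GANModule(critic=critic, generator=generator)  # starts in critic mode (gen_mode=False)
real = torch.randn(2, 3, 64, 64)
real_p = tst(real)
test_eq(real_p.shape, [2,1])

tst.switch() #tst is now in generator mode
noise = torch.randn(2, 100)
fake = tst(noise)
test_eq(fake.shape, real.shape)

tst.switch() #tst is back in critic mode
fake_p = tst(fake)
test_eq(fake_p.shape, [2,1])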

DenseResBlock

 DenseResBlock (nf:int, norm_type:NormType=<NormType.Batch: 1>, ks=3,
                stride=1, padding=None, bias=None, ndim=2, bn_1st=True,
                act_cls=<class 'torch.nn.modules.activation.ReLU'>,
                transpose=False, init='auto', xtra=None, bias_std=0.01,
                dilation:Union[int,Tuple[int,int]]=1, groups:int=1,
                padding_mode:str='zeros', device=None, dtype=None)

Resnet block of nf features. conv_kwargs are passed to conv_layer.

Type Default Details
nf int Number of features
norm_type NormType NormType.Batch Normalization type
ks int 3
stride int 1
padding NoneType None
bias NoneType None
ndim int 2
bn_1st bool True
act_cls type ReLU
transpose bool False
init str auto
xtra NoneType None
bias_std float 0.01
dilation Union 1
groups int 1
padding_mode str zeros
device NoneType None
dtype NoneType None
Returns SequentialEx

gan_critic

 gan_critic (n_channels:int=3, nf:int=128, n_blocks:int=3, p:float=0.15)

Critic to train a GAN.

Type Default Details
n_channels int 3 Number of channels of the input for the critic
nf int 128 Number of features for the critic
n_blocks int 3 Number of ResNet blocks within the critic
p float 0.15 Amount of dropout in the critic
Returns Sequential

GANLoss

 GANLoss (gen_loss_func:callable, crit_loss_func:callable,
          gan_model:GANModule)

Wrapper around crit_loss_func and gen_loss_func

Type Details
gen_loss_func callable Generator loss function
crit_loss_func callable Critic loss function
gan_model GANModule The GAN model

GANLoss.generator

 GANLoss.generator (output, target)

Evaluate the output with the critic, then use self.gen_loss_func to evaluate how well the critic was fooled by output.

Details
output Generator outputs
target Real images

GANLoss.critic

 GANLoss.critic (real_pred, input)

Create some fake_pred with the generator from input and compare them to real_pred in self.crit_loss_func.

Details
real_pred Critic predictions for real images
input Input noise vector to pass into generator

If the generator method is called, this loss function expects the output of the generator and some target (a batch of real images). It evaluates whether the generator successfully fooled the critic using gen_loss_func, which has the following signature:

def gen_loss_func(fake_pred, output, target):

This makes it possible to combine the output of the critic on output (the first argument, fake_pred) with output and target (for instance, if you want to mix the GAN loss with other losses).

If the critic method is called, this loss function expects the real_pred given by the critic and some input (the noise fed to the generator). It evaluates the critic using crit_loss_func, which has the following signature:

def crit_loss_func(real_pred, fake_pred):

where real_pred is the output of the critic on a batch of real images and fake_pred is generated from the noise using the generator.
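
As an illustration of the two signatures, here is a minimal pair of loss functions. The choice of binary cross-entropy on raw critic logits is an assumption made for this sketch; any functions with these signatures could be used instead.

```python
import torch
import torch.nn.functional as F

def gen_loss_func(fake_pred, output, target):
    # Low loss when the critic scores the generated images as real (label 1).
    # output and target are unused here but could feed an extra pixel loss.
    return F.binary_cross_entropy_with_logits(fake_pred, torch.ones_like(fake_pred))

def crit_loss_func(real_pred, fake_pred):
    # Low loss when real images are scored as 1 and fake images as 0
    real_loss = F.binary_cross_entropy_with_logits(real_pred, torch.ones_like(real_pred))
    fake_loss = F.binary_cross_entropy_with_logits(fake_pred, torch.zeros_like(fake_pred))
    return (real_loss + fake_loss) / 2

confident_real = torch.full((4, 1), 10.0)   # logits strongly saying "real"
confident_fake = torch.full((4, 1), -10.0)  # logits strongly saying "fake"
good_critic = crit_loss_func(confident_real, confident_fake)
bad_critic = crit_loss_func(confident_fake, confident_real)
```
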


AdaptiveLoss

 AdaptiveLoss (crit:callable)

Expand the target to match the output size before applying crit.
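
A sketch of the idea, assuming the critic returns several scores per image (for example one per patch) while the target holds a single label per image. `ExpandTargetLoss` is a hypothetical name for this illustration, not the library class.

```python
import torch
from torch import nn

class ExpandTargetLoss(nn.Module):
    "Hypothetical sketch: broadcast a per-item target over every output position."
    def __init__(self, crit):
        super().__init__()
        self.crit = crit

    def forward(self, output, target):
        # (bs,) -> (bs, 1) -> same shape as output
        expanded = target[:, None].expand_as(output).float()
        return self.crit(output, expanded)

loss_func = ExpandTargetLoss(nn.BCEWithLogitsLoss())
output = torch.zeros(4, 9)           # e.g. 9 scores per image
target = torch.tensor([1, 0, 1, 0])  # one label per image
loss = loss_func(output, target)
```
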


accuracy_thresh_expand

 accuracy_thresh_expand (y_pred:torch.Tensor, y_true:torch.Tensor,
                         thresh:float=0.5, sigmoid:bool=True)

Compute thresholded accuracy after expanding y_true to the size of y_pred.
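
The computation can be sketched as follows. This is an illustrative reimplementation under the assumption that y_true holds one label per item and y_pred holds several scores per item; the function name is hypothetical.

```python
import torch

def accuracy_thresh_expand_sketch(y_pred, y_true, thresh=0.5, sigmoid=True):
    # Optionally map logits to probabilities before thresholding
    if sigmoid:
        y_pred = y_pred.sigmoid()
    # Repeat each label across all predictions of the same item
    y_true = y_true[:, None].expand_as(y_pred)
    return ((y_pred > thresh) == y_true.bool()).float().mean()

y_pred = torch.tensor([[2.0, -1.0], [-3.0, -0.5]])  # two scores per item
y_true = torch.tensor([1, 0])                       # one label per item
acc = accuracy_thresh_expand_sketch(y_pred, y_true)
```

Here three of the four thresholded predictions agree with the expanded labels.
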

Callbacks for GAN training


set_freeze_model

 set_freeze_model (m:torch.nn.modules.module.Module, rg:bool)
Type Details
m Module Model to freeze/unfreeze
rg bool Value for requires_grad on the parameters of m: False freezes the model, True unfreezes it
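
Freezing in PyTorch amounts to toggling requires_grad on every parameter. A minimal sketch, assuming that is all the function needs to do (the helper name is hypothetical):

```python
from torch import nn

def set_requires_grad(m, rg):
    # rg=False stops gradient computation for m (frozen); rg=True re-enables it
    for p in m.parameters():
        p.requires_grad_(rg)

critic = nn.Linear(8, 1)
set_requires_grad(critic, False)
frozen_flags = [p.requires_grad for p in critic.parameters()]
set_requires_grad(critic, True)
trainable_flags = [p.requires_grad for p in critic.parameters()]
```
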

GANTrainer

 GANTrainer (switch_eval:bool=False, clip:None|float=None,
             beta:float=0.98, gen_first:bool=False, show_img:bool=True)

Callback to handle GAN Training.

Type Default Details
switch_eval bool False Whether the model should be set to eval mode when calculating loss
clip None | float None How much to clip the weights
beta float 0.98 Exponentially weighted smoothing of the losses beta
gen_first bool False Whether we start with generator training
show_img bool True Whether to show example generated images during training
Warning

The GANTrainer is useless on its own: you need to complete it with one of the following switchers.


FixedGANSwitcher

 FixedGANSwitcher (n_crit:int=1, n_gen:int=1)

Switcher to do n_crit iterations of the critic then n_gen iterations of the generator.

Type Default Details
n_crit int 1 How many steps of critic training before switching to generator
n_gen int 1 How many steps of generator training before switching to critic
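
The resulting schedule can be pictured with a small pure-Python generator. `fixed_schedule` is a hypothetical helper for illustration; it only models which network trains at each step, not the callback itself.

```python
def fixed_schedule(n_steps, n_crit=1, n_gen=1, gen_first=False):
    "Yield 'critic' or 'generator' for each training step (illustrative only)."
    gen_mode, count = gen_first, 0
    for _ in range(n_steps):
        yield "generator" if gen_mode else "critic"
        count += 1
        limit = n_gen if gen_mode else n_crit
        if count >= limit:
            # Enough steps in this mode: hand over to the other network
            gen_mode, count = not gen_mode, 0

schedule = list(fixed_schedule(8, n_crit=3, n_gen=1))
```

With n_crit=3 and n_gen=1, the critic trains for three steps, then the generator for one, and the pattern repeats.
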

AdaptiveGANSwitcher

 AdaptiveGANSwitcher (gen_thresh:None|float=None,
                      critic_thresh:None|float=None)

Switcher that goes back to generator/critic when the loss goes below gen_thresh/critic_thresh.

Type Default Details
gen_thresh None | float None Loss threshold for generator
critic_thresh None | float None Loss threshold for critic
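
The switching rule can be sketched as a pure function. This assumes that a threshold left at None means "switch after every step"; that reading, and the helper name `next_mode`, are assumptions of this example.

```python
def next_mode(gen_mode, loss, gen_thresh=None, critic_thresh=None):
    "Return the mode (True = generator) to use for the next step."
    # Use the threshold that belongs to the network currently training
    thresh = gen_thresh if gen_mode else critic_thresh
    should_switch = thresh is None or loss < thresh
    return (not gen_mode) if should_switch else gen_mode

# Critic keeps training while its loss is still above its threshold
stay = next_mode(gen_mode=False, loss=0.9, critic_thresh=0.5)
# Once the critic loss drops below the threshold, the generator takes over
switch = next_mode(gen_mode=False, loss=0.3, critic_thresh=0.5)
```
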

GANDiscriminativeLR

 GANDiscriminativeLR (mult_lr=5.0)

Callback that handles multiplying the learning rate by mult_lr for the critic.
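
The effect can be illustrated with a plain PyTorch optimizer: scale the learning rate up before a critic step and back down afterwards. The `scale_lr` helper and the use of SGD are assumptions of this sketch, not the callback's code.

```python
import torch
from torch import nn

model = nn.Linear(8, 1)
opt = torch.optim.SGD(model.parameters(), lr=2e-4)
mult_lr = 5.0

def scale_lr(opt, factor):
    # Multiply the learning rate of every parameter group in place
    for group in opt.param_groups:
        group["lr"] *= factor

scale_lr(opt, mult_lr)        # entering a critic step
critic_lr = opt.param_groups[0]["lr"]
scale_lr(opt, 1 / mult_lr)    # leaving the critic step
generator_lr = opt.param_groups[0]["lr"]
```
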

GAN data


InvisibleTensor

 InvisibleTensor (x, **kwargs)

TensorBase but show method does nothing


generate_noise

 generate_noise (fn, size=100)

Generate noise vector.

Type Default Details
fn Dummy argument so it works with DataBlock
size int 100 Size of returned noise vector
Returns InvisibleTensor

We use the generate_noise function to generate noise vectors to pass into the generator for image generation.
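
Conceptually, such a function only needs to ignore its argument and return random values. A sketch in plain PyTorch, assuming standard normal noise and returning an ordinary tensor rather than an InvisibleTensor:

```python
import torch

def generate_noise_sketch(fn, size=100):
    # fn is ignored; it exists only so the function can be used as get_x,
    # which is called once per item (here, once per image file)
    return torch.randn(size)

noise = generate_noise_sketch("any_file.jpg")
```
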

bs = 128
size = 64
dblock = DataBlock(blocks = (TransformBlock, ImageBlock),
                   get_x = generate_noise,
                   get_items = get_image_files,
                   splitter = IndexSplitter([]),
                   item_tfms=Resize(size, method=ResizeMethod.Crop), 
                   batch_tfms = Normalize.from_stats(torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5])))
path = untar_data(URLs.LSUN_BEDROOMS)
dls = dblock.dataloaders(path, path=path, bs=bs)
dls.show_batch(max_n=16)

GAN Learner


gan_loss_from_func

 gan_loss_from_func (loss_gen:callable, loss_crit:callable,
                     weights_gen:None|collections.abc.MutableSequence|tuple=None)

Define loss functions for a GAN from loss_gen and loss_crit.

Type Default Details
loss_gen callable A loss function for the generator. Evaluates generator output images and target real images
loss_crit callable A loss function for the critic. Evaluates predictions of real and fake images.
weights_gen None | collections.abc.MutableSequence | tuple None Weights for the generator and critic loss function
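
One way such a pair of losses could be assembled is sketched below. It assumes that loss_crit compares critic predictions with labels of ones (real) and zeros (fake), that loss_gen compares generated images with targets, and that weights_gen weighs those two terms in the generator loss. The function name and these details are assumptions of this example.

```python
import torch
from torch import nn

def gan_loss_from_func_sketch(loss_gen, loss_crit, weights_gen=None):
    # Default to equal weights for the adversarial and the content term
    w_adv, w_content = (1.0, 1.0) if weights_gen is None else weights_gen

    def loss_G(fake_pred, output, target):
        # The generator wants the critic to label its output as real (ones)
        ones = fake_pred.new_ones(fake_pred.shape[0])
        return w_adv * loss_crit(fake_pred, ones) + w_content * loss_gen(output, target)

    def loss_C(real_pred, fake_pred):
        ones = real_pred.new_ones(real_pred.shape[0])
        zeros = fake_pred.new_zeros(fake_pred.shape[0])
        return (loss_crit(real_pred, ones) + loss_crit(fake_pred, zeros)) / 2

    return loss_G, loss_C

loss_G, loss_C = gan_loss_from_func_sketch(nn.MSELoss(), nn.BCEWithLogitsLoss())
```

The two returned functions match the gen_loss_func and crit_loss_func signatures expected by GANLoss.
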

GANLearner

 GANLearner (dls:DataLoaders, generator:nn.Module, critic:nn.Module,
             gen_loss_func:callable, crit_loss_func:callable,
             switcher:Callback|None=None, gen_first:bool=False,
             switch_eval:bool=True, show_img:bool=True,
             clip:None|float=None, cbs:Callback|None|MutableSequence=None,
             metrics:None|MutableSequence|callable=None,
             loss_func:callable|None=None,
             opt_func:Optimizer|OptimWrapper=<function Adam>,
             lr:float|slice=0.001, splitter:callable=<function
             trainable_params>, path:str|Path|None=None,
             model_dir:str|Path='models', wd:float|int|None=None,
             wd_bn_bias:bool=False, train_bn:bool=True, moms:tuple=(0.95,
             0.85, 0.95), default_cbs:bool=True)

A Learner suitable for GANs.

Type Default Details
dls DataLoaders DataLoaders object for GAN data
generator nn.Module Generator model
critic nn.Module Critic model
gen_loss_func callable Generator loss function
crit_loss_func callable Critic loss function
switcher Callback | None None Callback for switching between generator and critic training, defaults to FixedGANSwitcher
gen_first bool False Whether we start with generator training
switch_eval bool True Whether the model should be set to eval mode when calculating loss
show_img bool True Whether to show example generated images during training
clip None | float None How much to clip the weights
cbs Callback | None | MutableSequence None Additional callbacks
metrics None | MutableSequence | callable None Metrics
loss_func callable | None None
opt_func Optimizer | OptimWrapper Adam
lr float | slice 0.001
splitter callable trainable_params
path str | Path | None None
model_dir str | Path models
wd float | int | None None
wd_bn_bias bool False
train_bn bool True
moms tuple (0.95, 0.85, 0.95)
default_cbs bool True

GANLearner.from_learners

 GANLearner.from_learners (gen_learn:Learner, crit_learn:Learner,
                           switcher:Callback|None=None,
                           weights_gen:None|MutableSequence|tuple=None,
                           gen_first:bool=False, switch_eval:bool=True,
                           show_img:bool=True, clip:None|float=None,
                           cbs:Callback|None|MutableSequence=None,
                           metrics:None|MutableSequence|callable=None,
                           loss_func:callable|None=None,
                           opt_func:Optimizer|OptimWrapper=<function
                           Adam>, lr:float|slice=0.001,
                           splitter:callable=<function trainable_params>,
                           path:str|Path|None=None,
                           model_dir:str|Path='models',
                           wd:float|int|None=None, wd_bn_bias:bool=False,
                           train_bn:bool=True, moms:tuple=(0.95, 0.85,
                           0.95), default_cbs:bool=True)

Create a GAN from gen_learn and crit_learn.


GANLearner.wgan

 GANLearner.wgan (dls:DataLoaders, generator:nn.Module, critic:nn.Module,
                  switcher:Callback|None=None, clip:None|float=0.01,
                  switch_eval:bool=False, gen_first:bool=False,
                  show_img:bool=True,
                  cbs:Callback|None|MutableSequence=None,
                  metrics:None|MutableSequence|callable=None,
                  loss_func:callable|None=None,
                  opt_func:Optimizer|OptimWrapper=<function Adam>,
                  lr:float|slice=0.001, splitter:callable=<function
                  trainable_params>, path:str|Path|None=None,
                  model_dir:str|Path='models', wd:float|int|None=None,
                  wd_bn_bias:bool=False, train_bn:bool=True,
                  moms:tuple=(0.95, 0.85, 0.95), default_cbs:bool=True)

Create a WGAN from dls, generator and critic.
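
For orientation, a WGAN replaces probability-based losses with differences of mean critic scores and keeps the critic weights inside a small range, which is what the clip argument controls. The sketch below uses one possible sign convention (the critic lowers its score on real images relative to fake ones, and the generator lowers the score of its fakes); the function names and that convention are assumptions of this example.

```python
import torch
from torch import nn

def wgan_gen_loss(fake_pred, output, target):
    # The generator tries to lower the critic's mean score on fakes
    return fake_pred.mean()

def wgan_crit_loss(real_pred, fake_pred):
    # The critic tries to score real images lower than fake ones
    return real_pred.mean() - fake_pred.mean()

def clip_weights(model, clip=0.01):
    # Clamp every parameter to [-clip, clip] after an optimizer step
    with torch.no_grad():
        for p in model.parameters():
            p.clamp_(-clip, clip)

critic = nn.Linear(8, 1)
clip_weights(critic, clip=0.01)
max_weight = max(p.abs().max().item() for p in critic.parameters())
```
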

from fastai.callback.all import *
generator = basic_generator(64, n_channels=3, n_extra_layers=1)
critic    = basic_critic   (64, n_channels=3, n_extra_layers=1, act_cls=partial(nn.LeakyReLU, negative_slope=0.2))
learn = GANLearner.wgan(dls, generator, critic, opt_func = RMSProp)
learn.recorder.train_metrics=True
learn.recorder.valid_metrics=False
learn.fit(1, 2e-4, wd=0.)
epoch train_loss gen_loss crit_loss time
0 -0.815071 0.646809 -1.140522 00:38
learn.show_results(max_n=9, ds_idx=0)