GAN stands for Generative Adversarial Nets, which were invented by Ian Goodfellow. The concept is that we train two models at the same time: a generator and a critic. The generator tries to make new images similar to the ones in a dataset, while the critic tries to tell real images apart from the ones the generator produces. The generator returns images, the critic a single number (usually a probability: 0. for fake images and 1. for real ones).

We train them against each other in the sense that at each step (more or less), we:

  1. Freeze the generator and train the critic for one step by:
    • getting one batch of real images (let's call that real)
    • generating one batch of fake images (let's call that fake)
    • having the critic evaluate each batch and computing a loss function from that; the important part is that it rewards the detection of real images and penalizes the fake ones
    • updating the weights of the critic with the gradients of this loss
  2. Freeze the critic and train the generator for one step by:
    • generating one batch of fake images
    • evaluating the critic on it
    • returning a loss that rewards the critic thinking those are real images
    • updating the weights of the generator with the gradients of this loss
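The two alternating steps above can be sketched in plain PyTorch. The tiny linear generator and critic below are placeholders for illustration, not fastai's models:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Placeholder models: a tiny generator mapping 8-d noise to 4-d "images",
# and a critic mapping them to a single score
generator = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
critic    = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 1))
g_opt = torch.optim.Adam(generator.parameters(), lr=1e-3)
c_opt = torch.optim.Adam(critic.parameters(), lr=1e-3)
bce = nn.BCEWithLogitsLoss()

real = torch.randn(32, 4)  # one batch of "real" data

# 1. Critic step: the generator is frozen via .detach(); the loss rewards
#    predicting 1. on real images and 0. on fake ones
fake = generator(torch.randn(32, 8)).detach()
crit_loss = (bce(critic(real), torch.ones(32, 1))
             + bce(critic(fake), torch.zeros(32, 1)))
c_opt.zero_grad(); crit_loss.backward(); c_opt.step()

# 2. Generator step: the critic is "frozen" by only stepping g_opt; the loss
#    rewards the critic predicting 1. (real) on generated images
gen_loss = bce(critic(generator(torch.randn(32, 8))), torch.ones(32, 1))
g_opt.zero_grad(); gen_loss.backward(); g_opt.step()
```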

Wrapping the modules

class GANModule[source]

GANModule(generator:nn.Module=None, critic:nn.Module=None, gen_mode:(None, bool)=False) :: Module

Wrapper around a generator and a critic to create a GAN.

Type Default Details
generator nn.Module None The generator PyTorch module
critic nn.Module None The discriminator PyTorch module
gen_mode (None, bool) False Whether the GAN should be set to generator mode

This is just a shell to contain the two models. When called, it delegates the input to either the generator or the critic, depending on the value of gen_mode.
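A minimal sketch of this delegation logic (a simplified illustration, not fastai's exact implementation):

```python
import torch
from torch import nn

# Simplified sketch of GANModule's behavior, for illustration only
class GANModuleSketch(nn.Module):
    def __init__(self, generator=None, critic=None, gen_mode=False):
        super().__init__()
        self.generator, self.critic, self.gen_mode = generator, critic, gen_mode

    def forward(self, *args):
        # Delegate to the generator or the critic depending on gen_mode
        return self.generator(*args) if self.gen_mode else self.critic(*args)

    def switch(self, gen_mode=None):
        # With gen_mode=None, toggle to the other mode
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode

gm = GANModuleSketch(generator=nn.Linear(100, 4), critic=nn.Linear(4, 1))
out = gm(torch.randn(2, 4))   # critic mode by default
gm.switch()                   # now in generator mode
img = gm(torch.randn(2, 100))
```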

GANModule.switch[source]

GANModule.switch(gen_mode:(None, bool)=None)

Put the module in generator mode if gen_mode is True, and in critic mode if it is False.

Type Default Details
gen_mode (None, bool) None Whether the GAN should be set to generator mode

By default (leaving gen_mode to None), this will put the module in the other mode (critic mode if it was in generator mode and vice versa).

basic_critic[source]

basic_critic(in_size:int, n_channels:int, n_features:int=64, n_extra_layers:int=0, norm_type:NormType=<NormType.Batch: 1>, ks=3, stride=1, padding=None, bias=None, ndim=2, bn_1st=True, act_cls=ReLU, transpose=False, init='auto', xtra=None, bias_std=0.01, dilation:Union[int, typing.Tuple[int, int]]=1, groups:int=1, padding_mode:str='zeros', device=None, dtype=None)

A basic critic for images n_channels x in_size x in_size.

Type Default Details
in_size int Input size for the critic (same as the output size of the generator)
n_channels int Number of channels of the input for the critic
n_features int 64 Number of features used in the critic
n_extra_layers int 0 Number of extra hidden layers in the critic
norm_type NormType NormType.Batch Type of normalization to use in the critic
Valid Keyword Arguments
ks int 3 Argument passed to ConvLayer.__init__
stride int 1 Argument passed to ConvLayer.__init__
padding None Argument passed to ConvLayer.__init__
bias None Argument passed to ConvLayer.__init__
ndim int 2 Argument passed to ConvLayer.__init__
bn_1st bool True Argument passed to ConvLayer.__init__
act_cls type ReLU Argument passed to ConvLayer.__init__
transpose bool False Argument passed to ConvLayer.__init__
init str auto Argument passed to ConvLayer.__init__
xtra None Argument passed to ConvLayer.__init__
bias_std float 0.01 Argument passed to ConvLayer.__init__
dilation typing.Union[int, typing.Tuple[int, int]] 1 Argument passed to ConvLayer.__init__
groups int 1 Argument passed to ConvLayer.__init__
padding_mode str zeros Argument passed to ConvLayer.__init__
device None Argument passed to ConvLayer.__init__
dtype None Argument passed to ConvLayer.__init__
Returns nn.Sequential

class AddChannels[source]

AddChannels(n_dim) :: Module

Add n_dim channels at the end of the input.
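Its effect can be sketched as follows (a hypothetical re-implementation for illustration):

```python
import torch
from torch import nn

# Illustrative sketch of AddChannels: append n_dim trailing unit dimensions,
# e.g. so a (bs, in_sz) noise batch becomes (bs, in_sz, 1, 1) and can be fed
# to the transposed convolutions of the generator
class AddChannelsSketch(nn.Module):
    def __init__(self, n_dim):
        super().__init__()
        self.n_dim = n_dim

    def forward(self, x):
        return x.view(*x.shape, *([1] * self.n_dim))

x = torch.randn(2, 100)
print(AddChannelsSketch(2)(x).shape)  # torch.Size([2, 100, 1, 1])
```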

basic_generator[source]

basic_generator(out_size:int, n_channels:int, in_sz:int=100, n_features:int=64, n_extra_layers:int=0, ks=3, stride=1, padding=None, bias=None, ndim=2, norm_type=<NormType.Batch: 1>, bn_1st=True, act_cls=ReLU, transpose=False, init='auto', xtra=None, bias_std=0.01, dilation:Union[int, typing.Tuple[int, int]]=1, groups:int=1, padding_mode:str='zeros', device=None, dtype=None)

A basic generator from in_sz to images n_channels x out_size x out_size.

Type Default Details
out_size int Output size for the generator (same as the input size for the critic)
n_channels int Number of channels of the output of the generator
in_sz int 100 Size of the input noise vector for the generator
n_features int 64 Number of features used in the generator
n_extra_layers int 0 Number of extra hidden layers in the generator
Valid Keyword Arguments
ks int 3 Argument passed to ConvLayer.__init__
stride int 1 Argument passed to ConvLayer.__init__
padding None Argument passed to ConvLayer.__init__
bias None Argument passed to ConvLayer.__init__
ndim int 2 Argument passed to ConvLayer.__init__
norm_type (NormType.Batch, NormType.BatchZero, NormType.Weight, NormType.Spectral, NormType.Instance, NormType.InstanceZero) NormType.Batch Argument passed to ConvLayer.__init__
bn_1st bool True Argument passed to ConvLayer.__init__
act_cls type ReLU Argument passed to ConvLayer.__init__
transpose bool False Argument passed to ConvLayer.__init__
init str auto Argument passed to ConvLayer.__init__
xtra None Argument passed to ConvLayer.__init__
bias_std float 0.01 Argument passed to ConvLayer.__init__
dilation typing.Union[int, typing.Tuple[int, int]] 1 Argument passed to ConvLayer.__init__
groups int 1 Argument passed to ConvLayer.__init__
padding_mode str zeros Argument passed to ConvLayer.__init__
device None Argument passed to ConvLayer.__init__
dtype None Argument passed to ConvLayer.__init__
Returns nn.Sequential
critic = basic_critic(64, 3)
generator = basic_generator(64, 3)
tst = GANModule(critic=critic, generator=generator)
real = torch.randn(2, 3, 64, 64)
real_p = tst(real)
test_eq(real_p.shape, [2,1])

tst.switch() #tst is now in generator mode
noise = torch.randn(2, 100)
fake = tst(noise)
test_eq(fake.shape, real.shape)

tst.switch() #tst is back in critic mode
fake_p = tst(fake)
test_eq(fake_p.shape, [2,1])

DenseResBlock[source]

DenseResBlock(nf:int, norm_type:NormType=<NormType.Batch: 1>, ks=3, stride=1, padding=None, bias=None, ndim=2, bn_1st=True, act_cls=ReLU, transpose=False, init='auto', xtra=None, bias_std=0.01, dilation:Union[int, typing.Tuple[int, int]]=1, groups:int=1, padding_mode:str='zeros', device=None, dtype=None)

ResNet block of nf features. Additional keyword arguments are passed to ConvLayer.

Type Default Details
nf int Number of features
norm_type NormType NormType.Batch Normalization type
Valid Keyword Arguments
ks int 3 Argument passed to ConvLayer.__init__
stride int 1 Argument passed to ConvLayer.__init__
padding None Argument passed to ConvLayer.__init__
bias None Argument passed to ConvLayer.__init__
ndim int 2 Argument passed to ConvLayer.__init__
bn_1st bool True Argument passed to ConvLayer.__init__
act_cls type ReLU Argument passed to ConvLayer.__init__
transpose bool False Argument passed to ConvLayer.__init__
init str auto Argument passed to ConvLayer.__init__
xtra None Argument passed to ConvLayer.__init__
bias_std float 0.01 Argument passed to ConvLayer.__init__
dilation typing.Union[int, typing.Tuple[int, int]] 1 Argument passed to ConvLayer.__init__
groups int 1 Argument passed to ConvLayer.__init__
padding_mode str zeros Argument passed to ConvLayer.__init__
device None Argument passed to ConvLayer.__init__
dtype None Argument passed to ConvLayer.__init__
Returns SequentialEx

gan_critic[source]

gan_critic(n_channels:int=3, nf:int=128, n_blocks:int=3, p:float=0.15)

Critic to train a GAN.

Type Default Details
n_channels int 3 Number of channels of the input for the critic
nf int 128 Number of features for the critic
n_blocks int 3 Number of ResNet blocks within the critic
p float 0.15 Amount of dropout in the critic

class GANLoss[source]

GANLoss(gen_loss_func:callable, crit_loss_func:callable, gan_model:GANModule) :: GANModule

Wrapper around crit_loss_func and gen_loss_func

Type Default Details
gen_loss_func callable Generator loss function
crit_loss_func callable Critic loss function
gan_model GANModule The GAN model

GANLoss.generator[source]

GANLoss.generator(output, target)

Evaluate output with the critic, then use self.gen_loss_func to measure how well the critic was fooled by output

Type Default Details
output Generator outputs
target Real images

GANLoss.critic[source]

GANLoss.critic(real_pred, input)

Create some fake_pred with the generator from input and compare them to real_pred in self.crit_loss_func.

Type Default Details
real_pred Critic predictions for real images
input Input noise vector to pass into generator

If the generator method is called, this loss function expects the output of the generator and some target (a batch of real images). It will evaluate whether the generator successfully fooled the critic using gen_loss_func. This loss function has the following signature:

def gen_loss_func(fake_pred, output, target):

so it can combine the output of the critic on output (which is the first argument, fake_pred) with output and target (if you want to mix the GAN loss with other losses, for instance).

If the critic method is called, this loss function expects the real_pred given by the critic and some input (the noise fed to the generator). It will evaluate the critic using crit_loss_func. This loss function has the following signature:

def crit_loss_func(real_pred, fake_pred):

where real_pred is the output of the critic on a batch of real images and fake_pred is generated from the noise using the generator.
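For instance, BCE-based loss functions matching these two signatures could look like this (an illustrative sketch, not fastai's defaults):

```python
import torch
import torch.nn.functional as F

def gen_loss_func(fake_pred, output, target):
    # Reward the critic predicting "real" (1.) on generated images;
    # output and target are available to mix in e.g. an L1 pixel loss
    return F.binary_cross_entropy_with_logits(fake_pred, torch.ones_like(fake_pred))

def crit_loss_func(real_pred, fake_pred):
    # Reward predicting 1. on real images and 0. on fake ones
    return (F.binary_cross_entropy_with_logits(real_pred, torch.ones_like(real_pred))
            + F.binary_cross_entropy_with_logits(fake_pred, torch.zeros_like(fake_pred)))
```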

class AdaptiveLoss[source]

AdaptiveLoss(crit:callable) :: Module

Expand the target to match the output size before applying crit.

accuracy_thresh_expand[source]

accuracy_thresh_expand(y_pred:Tensor, y_true:Tensor, thresh:float=0.5, sigmoid:bool=True)

Compute thresholded accuracy after expanding y_true to the size of y_pred.

Callbacks for GAN training

set_freeze_model[source]

set_freeze_model(m:nn.Module, rg:bool)

Type Default Details
m nn.Module Model to freeze/unfreeze
rg bool Value to set for requires_grad on the model's parameters (False freezes the model)

class GANTrainer[source]

GANTrainer(switch_eval:bool=False, clip:(None, float)=None, beta:float=0.98, gen_first:bool=False, show_img:bool=True) :: Callback

Callback to handle GAN Training.

Type Default Details
switch_eval bool False Whether the model should be set to eval mode when calculating loss
clip (None, float) None How much to clip the weights
beta float 0.98 Exponentially weighted smoothing of the losses beta
gen_first bool False Whether we start with generator training
show_img bool True Whether to show example generated images during training

class FixedGANSwitcher[source]

FixedGANSwitcher(n_crit:int=1, n_gen:int=1) :: Callback

Switcher to do n_crit iterations of the critic then n_gen iterations of the generator.

Type Default Details
n_crit int 1 How many steps of critic training before switching to generator
n_gen int 1 How many steps of generator training before switching to critic
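For example, with n_crit=5 and n_gen=1 the schedule alternates five critic steps with one generator step. The counting logic can be illustrated in plain Python (this helper is hypothetical, not part of fastai):

```python
# Illustrate the step schedule produced by a fixed switcher:
# n_crit critic iterations, then n_gen generator iterations, repeating
def fixed_switch_schedule(n_crit, n_gen, n_iters):
    sched, gen_mode, count = [], False, 0
    for _ in range(n_iters):
        sched.append('gen' if gen_mode else 'crit')
        count += 1
        if count == (n_gen if gen_mode else n_crit):
            gen_mode, count = not gen_mode, 0
    return sched

print(fixed_switch_schedule(5, 1, 12))  # 5 'crit' entries, one 'gen', repeating
```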

class AdaptiveGANSwitcher[source]

AdaptiveGANSwitcher(gen_thresh:(None, float)=None, critic_thresh:(None, float)=None) :: Callback

Switcher that goes back to generator/critic when the loss goes below gen_thresh/critic_thresh.

Type Default Details
gen_thresh (None, float) None Loss threshold for generator
critic_thresh (None, float) None Loss threshold for critic

class GANDiscriminativeLR[source]

GANDiscriminativeLR(mult_lr=5.0) :: Callback

Callback that handles multiplying the learning rate by mult_lr for the critic.

GAN data

class InvisibleTensor[source]

InvisibleTensor(x, **kwargs) :: TensorBase

TensorBase whose show method does nothing

generate_noise[source]

generate_noise(fn, size=100)

Generate noise vector.

Type Default Details
fn Dummy argument so it works with DataBlock
size int 100 Size of returned noise vector

We use the generate_noise function to generate noise vectors to pass into the generator for image generation.

bs = 128
size = 64
dblock = DataBlock(blocks = (TransformBlock, ImageBlock),
                   get_x = generate_noise,
                   get_items = get_image_files,
                   splitter = IndexSplitter([]),
                   item_tfms=Resize(size, method=ResizeMethod.Crop), 
                   batch_tfms = Normalize.from_stats(torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5])))
path = untar_data(URLs.LSUN_BEDROOMS)
dls = dblock.dataloaders(path, path=path, bs=bs)
dls.show_batch(max_n=16)

GAN Learner

gan_loss_from_func[source]

gan_loss_from_func(loss_gen:callable, loss_crit:callable, weights_gen:(None, list, tuple)=None)

Define loss functions for a GAN from loss_gen and loss_crit.

Type Default Details
loss_gen callable A loss function for the generator. Evaluates generator output images and target real images
loss_crit callable A loss function for the critic. Evaluates predictions of real and fake images.
weights_gen (None, list, tuple) None Weights for the generator and critic loss function
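Schematically, the returned pair of loss functions behaves like this. This is a simplified sketch, under the assumption that loss_crit compares predictions to a target of ones/zeros and loss_gen compares generated output to the target images:

```python
import torch

def gan_loss_from_func_sketch(loss_gen, loss_crit, weights_gen=None):
    w_crit, w_gen = weights_gen if weights_gen is not None else (1., 1.)

    def _loss_G(fake_pred, output, target):
        # Fool-the-critic term plus a direct generator loss (e.g. a pixel loss)
        ones = fake_pred.new_ones(fake_pred.shape[0])
        return w_crit * loss_crit(fake_pred, ones) + w_gen * loss_gen(output, target)

    def _loss_C(real_pred, fake_pred):
        # Average the critic's loss on real (target 1.) and fake (target 0.) batches
        ones  = real_pred.new_ones(real_pred.shape[0])
        zeros = fake_pred.new_zeros(fake_pred.shape[0])
        return (loss_crit(real_pred, ones) + loss_crit(fake_pred, zeros)) / 2

    return _loss_G, _loss_C
```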

class GANLearner[source]

GANLearner(dls:DataLoaders, generator:nn.Module, critic:nn.Module, gen_loss_func:callable, crit_loss_func:callable, switcher:(Callback, None)=None, gen_first:bool=False, switch_eval:bool=True, show_img:bool=True, clip:(None, float)=None, cbs:(Callback, None, list)=None, metrics:(None, list, callable)=None, loss_func=None, opt_func=Adam, lr=0.001, splitter=trainable_params, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95, 0.85, 0.95)) :: Learner

A Learner suitable for GANs.

Type Default Details
dls DataLoaders DataLoaders object for GAN data
generator nn.Module Generator model
critic nn.Module Critic model
gen_loss_func callable Generator loss function
crit_loss_func callable Critic loss function
switcher (Callback, None) None Callback for switching between generator and critic training, defaults to FixedGANSwitcher
gen_first bool False Whether we start with generator training
switch_eval bool True Whether the model should be set to eval mode when calculating loss
show_img bool True Whether to show example generated images during training
clip (None, float) None How much to clip the weights
cbs (Callback, None, list) None Additional callbacks
metrics (None, list, callable) None Metrics
Valid Keyword Arguments
loss_func None Argument passed to Learner.__init__
opt_func function Adam Argument passed to Learner.__init__
lr float 0.001 Argument passed to Learner.__init__
splitter function trainable_params Argument passed to Learner.__init__
path None Argument passed to Learner.__init__
model_dir str models Argument passed to Learner.__init__
wd None Argument passed to Learner.__init__
wd_bn_bias bool False Argument passed to Learner.__init__
train_bn bool True Argument passed to Learner.__init__
moms tuple (0.95, 0.85, 0.95) Argument passed to Learner.__init__

GANLearner.from_learners[source]

GANLearner.from_learners(gen_learn:Learner, crit_learn:Learner, switcher:(Callback, None)=None, weights_gen:(None, list, tuple)=None, gen_first:bool=False, switch_eval:bool=True, show_img:bool=True, clip:(None, float)=None, cbs:(Callback, None, list)=None, metrics:(None, list, callable)=None, loss_func=None, opt_func=Adam, lr=0.001, splitter=trainable_params, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95, 0.85, 0.95))

Create a GAN from gen_learn and crit_learn.

Type Default Details
gen_learn Learner A Learner object that contains the generator
crit_learn Learner A Learner object that contains the critic
switcher (Callback, None) None Callback for switching between generator and critic training, defaults to FixedGANSwitcher
weights_gen (None, list, tuple) None Weights for the generator and critic loss function
Valid Keyword Arguments
gen_first bool False Whether we start with generator training passed to GANLearner.__init__
switch_eval bool True Whether the model should be set to eval mode when calculating loss passed to GANLearner.__init__
show_img bool True Whether to show example generated images during training passed to GANLearner.__init__
clip (None, float) None How much to clip the weights passed to GANLearner.__init__
cbs (Callback, None, list) None Additional callbacks passed to GANLearner.__init__
metrics (None, list, callable) None Metrics passed to GANLearner.__init__
loss_func None Argument passed to GANLearner.__init__
opt_func function Adam Argument passed to GANLearner.__init__
lr float 0.001 Argument passed to GANLearner.__init__
splitter function trainable_params Argument passed to GANLearner.__init__
path None Argument passed to GANLearner.__init__
model_dir str models Argument passed to GANLearner.__init__
wd None Argument passed to GANLearner.__init__
wd_bn_bias bool False Argument passed to GANLearner.__init__
train_bn bool True Argument passed to GANLearner.__init__
moms tuple (0.95, 0.85, 0.95) Argument passed to GANLearner.__init__

GANLearner.wgan[source]

GANLearner.wgan(dls:DataLoaders, generator:nn.Module, critic:nn.Module, switcher:(Callback, None)=None, clip:(None, float)=0.01, switch_eval:bool=False, gen_first:bool=False, show_img:bool=True, cbs:(Callback, None, list)=None, metrics:(None, list, callable)=None, loss_func=None, opt_func=Adam, lr=0.001, splitter=trainable_params, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95, 0.85, 0.95))

Create a WGAN from dls, generator and critic.

Type Default Details
dls DataLoaders DataLoaders object for GAN data
generator nn.Module Generator model
critic nn.Module Critic model
switcher (Callback, None) None Callback for switching between generator and critic training, defaults to FixedGANSwitcher(n_crit=5, n_gen=1)
clip (None, float) 0.01 How much to clip the weights
switch_eval bool False Whether the model should be set to eval mode when calculating loss
Valid Keyword Arguments
gen_first bool False Whether we start with generator training passed to GANLearner.__init__
show_img bool True Whether to show example generated images during training passed to GANLearner.__init__
cbs (Callback, None, list) None Additional callbacks passed to GANLearner.__init__
metrics (None, list, callable) None Metrics passed to GANLearner.__init__
loss_func None Argument passed to GANLearner.__init__
opt_func function Adam Argument passed to GANLearner.__init__
lr float 0.001 Argument passed to GANLearner.__init__
splitter function trainable_params Argument passed to GANLearner.__init__
path None Argument passed to GANLearner.__init__
model_dir str models Argument passed to GANLearner.__init__
wd None Argument passed to GANLearner.__init__
wd_bn_bias bool False Argument passed to GANLearner.__init__
train_bn bool True Argument passed to GANLearner.__init__
moms tuple (0.95, 0.85, 0.95) Argument passed to GANLearner.__init__
from fastai.callback.all import *
generator = basic_generator(64, n_channels=3, n_extra_layers=1)
critic    = basic_critic   (64, n_channels=3, n_extra_layers=1, act_cls=partial(nn.LeakyReLU, negative_slope=0.2))
learn = GANLearner.wgan(dls, generator, critic, opt_func = RMSProp)
learn.recorder.train_metrics=True
learn.recorder.valid_metrics=False
learn.fit(1, 2e-4, wd=0.)
epoch train_loss gen_loss crit_loss time
0 -0.815071 0.646809 -1.140522 00:38
learn.show_results(max_n=9, ds_idx=0)