critic = basic_critic(64, 3)
generator = basic_generator(64, 3)
tst = GANModule(critic=critic, generator=generator)
real = torch.randn(2, 3, 64, 64)
real_p = tst(real)
test_eq(real_p.shape, [2,1])
tst.switch() #tst is now in generator mode
noise = torch.randn(2, 100)
fake = tst(noise)
test_eq(fake.shape, real.shape)
tst.switch() #tst is back in critic mode
fake_p = tst(fake)
test_eq(fake_p.shape, [2,1])
GAN
GAN stands for Generative Adversarial Nets, which were invented by Ian Goodfellow. The concept is that we train two models at the same time: a generator and a critic. The generator tries to make new images similar to the ones in a dataset, and the critic tries to tell real images apart from the ones the generator produces. The generator returns images; the critic returns a single number (usually a probability, 0. for fake images and 1. for real ones).
We train them against each other in the sense that at each step (more or less), we:
- Freeze the generator and train the critic for one step by:
  - getting one batch of true images (let’s call that real)
  - generating one batch of fake images (let’s call that fake)
  - having the critic evaluate each batch and computing a loss function from that; the important part is that it rewards the detection of real images and penalizes the fake ones
  - updating the weights of the critic with the gradients of this loss
- Freeze the critic and train the generator for one step by:
  - generating one batch of fake images
  - evaluating the critic on it
  - returning a loss that rewards the critic thinking those are real images
  - updating the weights of the generator with the gradients of this loss
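The two alternating steps can be sketched in plain PyTorch. This is only an illustration under assumptions, not how fastai implements training: the tiny linear `generator` and `critic`, the binary cross-entropy loss, the SGD optimizers and the helper `set_requires_grad` are all hypothetical stand-ins.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy models: 8-dim noise -> 4-dim "image", 4-dim "image" -> 1 logit.
generator = nn.Linear(8, 4)
critic = nn.Linear(4, 1)
opt_gen = torch.optim.SGD(generator.parameters(), lr=0.1)
opt_crit = torch.optim.SGD(critic.parameters(), lr=0.1)
bce = nn.BCEWithLogitsLoss()

def set_requires_grad(model, flag):
    "Freeze (flag=False) or unfreeze (flag=True) all parameters of `model`."
    for p in model.parameters():
        p.requires_grad_(flag)

def critic_step(real, noise):
    # Freeze the generator, train the critic on one real and one fake batch.
    set_requires_grad(generator, False)
    set_requires_grad(critic, True)
    fake = generator(noise)
    real_pred, fake_pred = critic(real), critic(fake)
    loss = bce(real_pred, torch.ones_like(real_pred)) + bce(fake_pred, torch.zeros_like(fake_pred))
    opt_crit.zero_grad()
    loss.backward()
    opt_crit.step()
    return loss.item()

def generator_step(noise):
    # Freeze the critic, reward the generator when the critic says "real".
    set_requires_grad(critic, False)
    set_requires_grad(generator, True)
    fake_pred = critic(generator(noise))
    loss = bce(fake_pred, torch.ones_like(fake_pred))
    opt_gen.zero_grad()
    loss.backward()
    opt_gen.step()
    return loss.item()

real = torch.randn(16, 4)
crit_loss = critic_step(real, torch.randn(16, 8))
gen_loss = generator_step(torch.randn(16, 8))
```

Because the critic's parameters do not require gradients during `generator_step`, only the generator's weights change in that step, and vice versa.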
The fastai library provides support for training GANs through the GANTrainer, but doesn’t include more than basic models.
Wrapping the modules
GANModule
GANModule (generator:nn.Module=None, critic:nn.Module=None, gen_mode:None|bool=False)
Wrapper around a generator and a critic to create a GAN.
| | Type | Default | Details |
|---|---|---|---|
| generator | Module | None | The generator PyTorch module |
| critic | Module | None | The discriminator PyTorch module |
| gen_mode | None \| bool | False | Whether the GAN should be set to generator mode |
This is just a shell to contain the two models. When called, it will delegate the input to either the generator or the critic, depending on the value of gen_mode.
GANModule.switch
GANModule.switch (gen_mode:None|bool=None)
Put the module in generator mode if gen_mode is True, in critic mode otherwise.
| | Type | Default | Details |
|---|---|---|---|
| gen_mode | None \| bool | None | Whether the GAN should be set to generator mode |
By default (leaving gen_mode as None), this will put the module in the other mode (critic mode if it was in generator mode and vice versa).
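The delegation and toggling behaviour can be illustrated with a minimal sketch. `TinyGANModule` is a hypothetical stand-in written in plain Python, not fastai's GANModule, and the two lambdas stand in for real PyTorch modules.

```python
class TinyGANModule:
    "Hypothetical wrapper that forwards calls to a generator or a critic."
    def __init__(self, generator, critic, gen_mode=False):
        self.generator, self.critic, self.gen_mode = generator, critic, gen_mode

    def __call__(self, *args):
        # Delegate to whichever model the current mode selects.
        return self.generator(*args) if self.gen_mode else self.critic(*args)

    def switch(self, gen_mode=None):
        # None toggles the mode; True/False sets it explicitly.
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode

tst = TinyGANModule(generator=lambda z: f"image({z})", critic=lambda x: f"score({x})")
in_critic_mode = tst("x")      # critic mode by default
tst.switch()                   # toggle to generator mode
in_generator_mode = tst("z")
tst.switch(gen_mode=False)     # explicitly back to critic mode
```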
basic_critic
basic_critic (in_size:int, n_channels:int, n_features:int=64, n_extra_layers:int=0, norm_type:NormType=<NormType.Batch: 1>, ks=3, stride=1, padding=None, bias=None, ndim=2, bn_1st=True, act_cls=<class 'torch.nn.modules.activation.ReLU'>, transpose=False, init='auto', xtra=None, bias_std=0.01, dilation:Union[int,tuple[int,int]]=1, groups:int=1, padding_mode:Literal['zeros','reflect','replicate','circular']='zeros', device=None, dtype=None)
A basic critic for images n_channels x in_size x in_size.
| | Type | Default | Details |
|---|---|---|---|
| in_size | int | | Input size for the critic (same as the output size of the generator) |
| n_channels | int | | Number of channels of the input for the critic |
| n_features | int | 64 | Number of features used in the critic |
| n_extra_layers | int | 0 | Number of extra hidden layers in the critic |
| norm_type | NormType | NormType.Batch | Type of normalization to use in the critic |
| ks | int | 3 | |
| stride | int | 1 | |
| padding | NoneType | None | |
| bias | NoneType | None | |
| ndim | int | 2 | |
| bn_1st | bool | True | |
| act_cls | type | ReLU | |
| transpose | bool | False | |
| init | str | auto | |
| xtra | NoneType | None | |
| bias_std | float | 0.01 | |
| dilation | Union | 1 | |
| groups | int | 1 | |
| padding_mode | Literal | zeros | |
| device | NoneType | None | |
| dtype | NoneType | None | |
| Returns | nn.Sequential |
AddChannels
AddChannels (n_dim)
Add n_dim channels at the end of the input.
basic_generator
basic_generator (out_size:int, n_channels:int, in_sz:int=100, n_features:int=64, n_extra_layers:int=0, ks=3, stride=1, padding=None, bias=None, ndim=2, norm_type=<NormType.Batch: 1>, bn_1st=True, act_cls=<class 'torch.nn.modules.activation.ReLU'>, transpose=False, init='auto', xtra=None, bias_std=0.01, dilation:Union[int,tuple[int,int]]=1, groups:int=1, padding_mode:Literal['zeros','reflect','replicate','circular']='zeros', device=None, dtype=None)
A basic generator from in_sz to images n_channels x out_size x out_size.
| | Type | Default | Details |
|---|---|---|---|
| out_size | int | | Output size for the generator (same as the input size for the critic) |
| n_channels | int | | Number of channels of the output of the generator |
| in_sz | int | 100 | Size of the input noise vector for the generator |
| n_features | int | 64 | Number of features used in the generator |
| n_extra_layers | int | 0 | Number of extra hidden layers in the generator |
| ks | int | 3 | |
| stride | int | 1 | |
| padding | NoneType | None | |
| bias | NoneType | None | |
| ndim | int | 2 | |
| norm_type | NormType | NormType.Batch | |
| bn_1st | bool | True | |
| act_cls | type | ReLU | |
| transpose | bool | False | |
| init | str | auto | |
| xtra | NoneType | None | |
| bias_std | float | 0.01 | |
| dilation | Union | 1 | |
| groups | int | 1 | |
| padding_mode | Literal | zeros | |
| device | NoneType | None | |
| dtype | NoneType | None | |
| Returns | nn.Sequential |
DenseResBlock
DenseResBlock (nf:int, norm_type:NormType=<NormType.Batch: 1>, ks=3, stride=1, padding=None, bias=None, ndim=2, bn_1st=True, act_cls=<class 'torch.nn.modules.activation.ReLU'>, transpose=False, init='auto', xtra=None, bias_std=0.01, dilation:Union[int,tuple[int,int]]=1, groups:int=1, padding_mode:Literal['zeros','reflect','replicate','circular']='zeros', device=None, dtype=None)
Resnet block of nf features. conv_kwargs are passed to conv_layer.
| | Type | Default | Details |
|---|---|---|---|
| nf | int | | Number of features |
| norm_type | NormType | NormType.Batch | Normalization type |
| ks | int | 3 | |
| stride | int | 1 | |
| padding | NoneType | None | |
| bias | NoneType | None | |
| ndim | int | 2 | |
| bn_1st | bool | True | |
| act_cls | type | ReLU | |
| transpose | bool | False | |
| init | str | auto | |
| xtra | NoneType | None | |
| bias_std | float | 0.01 | |
| dilation | Union | 1 | |
| groups | int | 1 | |
| padding_mode | Literal | zeros | |
| device | NoneType | None | |
| dtype | NoneType | None | |
| Returns | SequentialEx |
gan_critic
gan_critic (n_channels:int=3, nf:int=128, n_blocks:int=3, p:float=0.15)
Critic to train a GAN.
| | Type | Default | Details |
|---|---|---|---|
| n_channels | int | 3 | Number of channels of the input for the critic |
| nf | int | 128 | Number of features for the critic |
| n_blocks | int | 3 | Number of ResNet blocks within the critic |
| p | float | 0.15 | Amount of dropout in the critic |
| Returns | Sequential |
GANLoss
GANLoss (gen_loss_func:Callable, crit_loss_func:Callable, gan_model:GANModule)
Wrapper around crit_loss_func and gen_loss_func
| | Type | Details |
|---|---|---|
| gen_loss_func | Callable | Generator loss function |
| crit_loss_func | Callable | Critic loss function |
| gan_model | GANModule | The GAN model |
GANLoss.generator
GANLoss.generator (output, target)
Evaluate the output with the critic, then use self.gen_loss_func to evaluate how well the critic was fooled by output.
| | Details |
|---|---|
| output | Generator outputs |
| target | Real images |
GANLoss.critic
GANLoss.critic (real_pred, input)
Create some fake_pred with the generator from input and compare them to real_pred in self.crit_loss_func.
| | Details |
|---|---|
| real_pred | Critic predictions for real images |
| input | Input noise vector to pass into generator |
If the generator method is called, this loss function expects the output of the generator and some target (a batch of real images). It will evaluate if the generator successfully fooled the critic using gen_loss_func. This loss function has the following signature
def gen_loss_func(fake_pred, output, target):
to be able to combine the output of the critic on output (which is the first argument, fake_pred) with output and target (if you want to mix the GAN loss with other losses, for instance).
If the critic method is called, this loss function expects the real_pred given by the critic and some input (the noise fed to the generator). It will evaluate the critic using crit_loss_func. This loss function has the following signature
def crit_loss_func(real_pred, fake_pred):
where real_pred is the output of the critic on a batch of real images and fake_pred is generated from the noise using the generator.
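As an illustration of the two signatures, here is a hedged sketch of one possible pair of loss functions. The names `my_gen_loss` and `my_crit_loss`, the use of binary cross-entropy on logits, and the arbitrary 0.1 weight on the pixel term are all assumptions made for the example; they are not the losses fastai ships.

```python
import torch
import torch.nn.functional as F

def my_gen_loss(fake_pred, output, target):
    # Adversarial term: the generator wants the critic to label fakes as real (1).
    adv = F.binary_cross_entropy_with_logits(fake_pred, torch.ones_like(fake_pred))
    # Optional extra term mixing in a pixel loss between generated and real images
    # (the 0.1 weight is an arbitrary choice for this sketch).
    return adv + 0.1 * F.l1_loss(output, target)

def my_crit_loss(real_pred, fake_pred):
    # The critic wants real images labelled 1 and fake images labelled 0.
    real_loss = F.binary_cross_entropy_with_logits(real_pred, torch.ones_like(real_pred))
    fake_loss = F.binary_cross_entropy_with_logits(fake_pred, torch.zeros_like(fake_pred))
    return (real_loss + fake_loss) / 2

# A critic that is currently winning: high logits on real, low logits on fake.
real_pred, fake_pred = torch.tensor([[3.0], [2.0]]), torch.tensor([[-2.0], [-3.0]])
output, target = torch.zeros(2, 3, 4, 4), torch.ones(2, 3, 4, 4)
g = my_gen_loss(fake_pred, output, target)
c = my_crit_loss(real_pred, fake_pred)
```

With these inputs the critic loss is small and the generator loss is large, which matches the intuition that the critic is not being fooled.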
AdaptiveLoss
AdaptiveLoss (crit:Callable)
Expand the target to match the output size before applying crit.
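The expansion can be sketched as follows, assuming the model returns several predictions per item (shape `(batch, n)`) while the target holds one label per item. `ExpandTargetLoss` is a hypothetical class written for illustration, not the library's implementation.

```python
import torch
from torch import nn

class ExpandTargetLoss(nn.Module):
    "Hypothetical wrapper: broadcast a per-item target to the shape of `output`."
    def __init__(self, crit):
        super().__init__()
        self.crit = crit

    def forward(self, output, target):
        # target: (batch,) -> (batch, 1) -> same shape as output
        expanded = target[:, None].expand_as(output).float()
        return self.crit(output, expanded)

loss_func = ExpandTargetLoss(nn.BCEWithLogitsLoss())
output = torch.zeros(4, 5)               # 5 logits per item
target = torch.tensor([1, 0, 1, 0])      # one label per item
loss = loss_func(output, target)
```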
accuracy_thresh_expand
accuracy_thresh_expand (y_pred:torch.Tensor, y_true:torch.Tensor, thresh:float=0.5, sigmoid:bool=True)
Compute thresholded accuracy after expanding y_true to the size of y_pred.
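A minimal sketch of this metric, assuming predictions of shape `(batch, n)` and one binary label per item. The function name `thresh_accuracy_expanded` is hypothetical and the body is an illustrative re-implementation, not the library code.

```python
import torch

def thresh_accuracy_expanded(y_pred, y_true, thresh=0.5, sigmoid=True):
    "Hypothetical re-implementation for illustration only."
    if sigmoid:
        y_pred = y_pred.sigmoid()
    # Repeat each item's label across all of that item's predictions.
    expanded = y_true[:, None].expand_as(y_pred).bool()
    return ((y_pred > thresh) == expanded).float().mean()

y_pred = torch.tensor([[2.0, -1.0], [-3.0, -2.0]])   # logits, 2 per item
y_true = torch.tensor([1, 0])                         # one label per item
acc = thresh_accuracy_expanded(y_pred, y_true)
```

Here three of the four thresholded predictions agree with the expanded labels.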
Callbacks for GAN training
set_freeze_model
set_freeze_model (m:torch.nn.modules.module.Module, rg:bool)
| | Type | Details |
|---|---|---|
| m | Module | Model to freeze/unfreeze |
| rg | bool | Requires grad argument applied to every parameter. False freezes the model, True unfreezes it |
GANTrainer
GANTrainer (switch_eval:bool=False, clip:None|float=None, beta:float=0.98, gen_first:bool=False, show_img:bool=True)
Callback to handle GAN Training.
| | Type | Default | Details |
|---|---|---|---|
| switch_eval | bool | False | Whether the model should be set to eval mode when calculating loss |
| clip | None \| float | None | How much to clip the weights |
| beta | float | 0.98 | Smoothing factor beta for the exponentially weighted average of the losses |
| gen_first | bool | False | Whether we start with generator training |
| show_img | bool | True | Whether to show example generated images during training |
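To make the role of beta concrete, the sketch below shows one common way to compute an exponentially weighted average with bias correction. The class name `SmoothedValue` is hypothetical, and the exact formula the library uses is an assumption here.

```python
class SmoothedValue:
    "Hypothetical exponentially weighted moving average with bias correction."
    def __init__(self, beta=0.98):
        self.beta, self.count, self.raw = beta, 0, 0.0

    def update(self, value):
        self.count += 1
        # Blend the new value into the running average.
        self.raw = self.beta * self.raw + (1 - self.beta) * value
        # Divide by (1 - beta**count) to undo the bias toward the initial 0.
        return self.raw / (1 - self.beta ** self.count)

smoother = SmoothedValue(beta=0.98)
smoothed = [smoother.update(v) for v in [1.0, 1.0, 1.0]]
```

A beta close to 1 gives a smoother but slower-moving estimate; with the bias correction, a constant input is reproduced exactly from the first step.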
The GANTrainer is useless on its own; you need to complete it with one of the following switchers.
FixedGANSwitcher
FixedGANSwitcher (n_crit:int=1, n_gen:int=1)
Switcher to do n_crit iterations of the critic then n_gen iterations of the generator.
| | Type | Default | Details |
|---|---|---|---|
| n_crit | int | 1 | How many steps of critic training before switching to generator |
| n_gen | int | 1 | How many steps of generator training before switching to critic |
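The resulting schedule can be sketched in plain Python. `fixed_schedule` is a hypothetical helper that only lists which model would be trained at each step; it does not use fastai's callback machinery.

```python
def fixed_schedule(n_crit=1, n_gen=1, n_steps=8, gen_first=False):
    "Hypothetical helper: list which model is trained at each step."
    modes, gen_mode, count = [], gen_first, 0
    for _ in range(n_steps):
        modes.append("gen" if gen_mode else "crit")
        count += 1
        # Switch once the current model has had its quota of iterations.
        if count == (n_gen if gen_mode else n_crit):
            gen_mode, count = not gen_mode, 0
    return modes

schedule = fixed_schedule(n_crit=5, n_gen=1, n_steps=12)
```

With `n_crit=5` and `n_gen=1`, the critic is trained for five steps, then the generator for one, and the pattern repeats.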
AdaptiveGANSwitcher
AdaptiveGANSwitcher (gen_thresh:None|float=None, critic_thresh:None|float=None)
Switcher that goes back to generator/critic when the loss goes below gen_thresh/critic_thresh.
| | Type | Default | Details |
|---|---|---|---|
| gen_thresh | None \| float | None | Loss threshold for generator |
| critic_thresh | None \| float | None | Loss threshold for critic |
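A threshold-based switching rule could look like the sketch below. `next_mode` is a hypothetical function, and which threshold applies in which mode is an assumption: here the current model keeps training until its own loss drops below its own threshold, and a threshold of None disables switching for that mode.

```python
def next_mode(gen_mode, loss, gen_thresh=None, critic_thresh=None):
    "Hypothetical rule: leave the current mode once its loss drops below its threshold."
    thresh = gen_thresh if gen_mode else critic_thresh
    if thresh is not None and loss < thresh:
        return not gen_mode
    return gen_mode

# Start in critic mode (gen_mode=False) and feed a made-up sequence of losses.
gen_mode = False
history = []
for loss in [0.9, 0.7, 0.4, 1.2, 0.8, 0.3]:
    history.append("gen" if gen_mode else "crit")
    gen_mode = next_mode(gen_mode, loss, gen_thresh=0.6, critic_thresh=0.5)
```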
GANDiscriminativeLR
GANDiscriminativeLR (mult_lr=5.0)
Callback that handles multiplying the learning rate by mult_lr for the critic.
GAN data
InvisibleTensor
InvisibleTensor (x, **kwargs)
TensorBase but show method does nothing
generate_noise
generate_noise (fn, size=100)
Generate noise vector.
| | Type | Default | Details |
|---|---|---|---|
| fn | | | Dummy argument so it works with DataBlock |
| size | int | 100 | Size of returned noise vector |
| Returns | InvisibleTensor | | |
We use the generate_noise function to generate noise vectors to pass into the generator for image generation.
bs = 128
size = 64
dblock = DataBlock(blocks = (TransformBlock, ImageBlock),
                   get_x = generate_noise,
                   get_items = get_image_files,
                   splitter = IndexSplitter([]),
                   item_tfms=Resize(size, method=ResizeMethod.Crop),
                   batch_tfms = Normalize.from_stats(torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5])))
path = untar_data(URLs.LSUN_BEDROOMS)
dls = dblock.dataloaders(path, path=path, bs=bs)
dls.show_batch(max_n=16)
GAN Learner
gan_loss_from_func
gan_loss_from_func (loss_gen:Callable, loss_crit:Callable, weights_gen:None|collections.abc.MutableSequence|tuple=None)
Define loss functions for a GAN from loss_gen and loss_crit.
| | Type | Default | Details |
|---|---|---|---|
| loss_gen | Callable | | A loss function for the generator. Evaluates generator output images and target real images |
| loss_crit | Callable | | A loss function for the critic. Evaluates predictions of real and fake images. |
| weights_gen | None \| collections.abc.MutableSequence \| tuple | None | Weights for the generator and critic loss function |
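One plausible way to build a generator/critic loss pair from two ordinary losses is sketched below. `make_gan_losses` is a hypothetical helper, not the library function. In particular, how weights_gen is applied (first weight on the adversarial term, second on the output-versus-target term) and the averaging in the critic loss are assumptions.

```python
import torch
from torch import nn

def make_gan_losses(loss_gen, loss_crit, weights_gen=(1.0, 1.0)):
    "Hypothetical helper returning (gen_loss_func, crit_loss_func)."
    def gen_loss_func(fake_pred, output, target):
        # Adversarial term (fakes should look real) plus a weighted output/target term.
        ones = torch.ones_like(fake_pred)
        return weights_gen[0] * loss_crit(fake_pred, ones) + weights_gen[1] * loss_gen(output, target)

    def crit_loss_func(real_pred, fake_pred):
        # Real predictions are compared to 1, fake predictions to 0.
        ones, zeros = torch.ones_like(real_pred), torch.zeros_like(fake_pred)
        return (loss_crit(real_pred, ones) + loss_crit(fake_pred, zeros)) / 2

    return gen_loss_func, crit_loss_func

gen_loss_func, crit_loss_func = make_gan_losses(
    nn.MSELoss(), nn.BCEWithLogitsLoss(), weights_gen=(1.0, 50.0))
fake_pred, real_pred = torch.zeros(2, 1), torch.zeros(2, 1)
output, target = torch.zeros(2, 3, 8, 8), torch.ones(2, 3, 8, 8)
g = gen_loss_func(fake_pred, output, target)
c = crit_loss_func(real_pred, fake_pred)
```

The two returned functions have the same signatures as the gen_loss_func and crit_loss_func described for GANLoss.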
GANLearner
GANLearner (dls:DataLoaders, generator:nn.Module, critic:nn.Module, gen_loss_func:Callable, crit_loss_func:Callable, switcher:Callback|None=None, gen_first:bool=False, switch_eval:bool=True, show_img:bool=True, clip:None|float=None, cbs:Callback|None|MutableSequence=None, metrics:None|MutableSequence|Callable=None, loss_func:Callable|None=None, opt_func:Optimizer|OptimWrapper=<function Adam>, lr:float|slice=0.001, splitter:Callable=<function trainable_params>, path:str|Path|None=None, model_dir:str|Path='models', wd:float|int|None=None, wd_bn_bias:bool=False, train_bn:bool=True, moms:tuple=(0.95, 0.85, 0.95), default_cbs:bool=True)
A Learner suitable for GANs.
| | Type | Default | Details |
|---|---|---|---|
| dls | DataLoaders | | DataLoaders object for GAN data |
| generator | nn.Module | | Generator model |
| critic | nn.Module | | Critic model |
| gen_loss_func | Callable | | Generator loss function |
| crit_loss_func | Callable | | Critic loss function |
| switcher | Callback \| None | None | Callback for switching between generator and critic training, defaults to FixedGANSwitcher |
| gen_first | bool | False | Whether we start with generator training |
| switch_eval | bool | True | Whether the model should be set to eval mode when calculating loss |
| show_img | bool | True | Whether to show example generated images during training |
| clip | None \| float | None | How much to clip the weights |
| cbs | Callback \| None \| MutableSequence | None | Additional callbacks |
| metrics | None \| MutableSequence \| Callable | None | Metrics |
| loss_func | Callable \| None | None | |
| opt_func | Optimizer \| OptimWrapper | Adam | |
| lr | float \| slice | 0.001 | |
| splitter | Callable | trainable_params | |
| path | str \| Path \| None | None | |
| model_dir | str \| Path | models | |
| wd | float \| int \| None | None | |
| wd_bn_bias | bool | False | |
| train_bn | bool | True | |
| moms | tuple | (0.95, 0.85, 0.95) | |
| default_cbs | bool | True | |
GANLearner.from_learners
GANLearner.from_learners (gen_learn:Learner, crit_learn:Learner, switcher:Callback|None=None, weights_gen:None|MutableSequence|tuple=None, gen_first:bool=False, switch_eval:bool=True, show_img:bool=True, clip:None|float=None, cbs:Callback|None|MutableSequence=None, metrics:None|MutableSequence|Callable=None, loss_func:Callable|None=None, opt_func:Optimizer|OptimWrapper=<function Adam>, lr:float|slice=0.001, splitter:Callable=<function trainable_params>, path:str|Path|None=None, model_dir:str|Path='models', wd:float|int|None=None, wd_bn_bias:bool=False, train_bn:bool=True, moms:tuple=(0.95, 0.85, 0.95), default_cbs:bool=True)
Create a GAN from gen_learn and crit_learn.
| | Type | Default | Details |
|---|---|---|---|
| gen_learn | Learner | | A Learner object that contains the generator |
| crit_learn | Learner | | A Learner object that contains the critic |
| switcher | Callback \| None | None | Callback for switching between generator and critic training, defaults to FixedGANSwitcher |
| weights_gen | None \| MutableSequence \| tuple | None | Weights for the generator and critic loss function |
| gen_first | bool | False | Whether we start with generator training |
| switch_eval | bool | True | Whether the model should be set to eval mode when calculating loss |
| show_img | bool | True | Whether to show example generated images during training |
| clip | None \| float | None | How much to clip the weights |
| cbs | Callback \| None \| MutableSequence | None | Additional callbacks |
| metrics | None \| MutableSequence \| Callable | None | Metrics |
| loss_func | Optional | None | Loss function. Defaults to dls loss |
| opt_func | fastai.optimizer.Optimizer \| fastai.optimizer.OptimWrapper | Adam | Optimization function for training |
| lr | float \| slice | 0.001 | Default learning rate |
| splitter | Callable | trainable_params | Split model into parameter groups. Defaults to one parameter group |
| path | str \| pathlib.Path \| None | None | Parent directory to save, load, and export models. Defaults to dls path |
| model_dir | str \| pathlib.Path | models | Subdirectory to save and load models |
| wd | float \| int \| None | None | Default weight decay |
| wd_bn_bias | bool | False | Apply weight decay to normalization and bias parameters |
| train_bn | bool | True | Train frozen normalization layers |
| moms | tuple | (0.95, 0.85, 0.95) | Default momentum for schedulers |
| default_cbs | bool | True | Include default Callbacks |
GANLearner.wgan
GANLearner.wgan (dls:DataLoaders, generator:nn.Module, critic:nn.Module, switcher:Callback|None=None, clip:None|float=0.01, switch_eval:bool=False, gen_first:bool=False, show_img:bool=True, cbs:Callback|None|MutableSequence=None, metrics:None|MutableSequence|Callable=None, loss_func:Callable|None=None, opt_func:Optimizer|OptimWrapper=<function Adam>, lr:float|slice=0.001, splitter:Callable=<function trainable_params>, path:str|Path|None=None, model_dir:str|Path='models', wd:float|int|None=None, wd_bn_bias:bool=False, train_bn:bool=True, moms:tuple=(0.95, 0.85, 0.95), default_cbs:bool=True)
Create a WGAN from dls, generator and critic.
| | Type | Default | Details |
|---|---|---|---|
| dls | DataLoaders | | DataLoaders object for GAN data |
| generator | nn.Module | | Generator model |
| critic | nn.Module | | Critic model |
| switcher | Callback \| None | None | Callback for switching between generator and critic training, defaults to FixedGANSwitcher(n_crit=5, n_gen=1) |
| clip | None \| float | 0.01 | How much to clip the weights |
| switch_eval | bool | False | Whether the model should be set to eval mode when calculating loss |
| gen_first | bool | False | Whether we start with generator training |
| show_img | bool | True | Whether to show example generated images during training |
| cbs | Callback \| None \| MutableSequence | None | Additional callbacks |
| metrics | None \| MutableSequence \| Callable | None | Metrics |
| loss_func | Optional | None | Loss function. Defaults to dls loss |
| opt_func | fastai.optimizer.Optimizer \| fastai.optimizer.OptimWrapper | Adam | Optimization function for training |
| lr | float \| slice | 0.001 | Default learning rate |
| splitter | Callable | trainable_params | Split model into parameter groups. Defaults to one parameter group |
| path | str \| pathlib.Path \| None | None | Parent directory to save, load, and export models. Defaults to dls path |
| model_dir | str \| pathlib.Path | models | Subdirectory to save and load models |
| wd | float \| int \| None | None | Default weight decay |
| wd_bn_bias | bool | False | Apply weight decay to normalization and bias parameters |
| train_bn | bool | True | Train frozen normalization layers |
| moms | tuple | (0.95, 0.85, 0.95) | Default momentum for schedulers |
| default_cbs | bool | True | Include default Callbacks |
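The clip argument corresponds to weight clipping as commonly used in WGAN training, where parameters are clamped to the range [-clip, clip]. A minimal sketch follows, assuming a toy critic; `clip_weights` is a hypothetical helper, and exactly when and on which model the library applies clipping is not shown here.

```python
import torch
from torch import nn

def clip_weights(model, clip=0.01):
    "Hypothetical helper: clamp every parameter of `model` to [-clip, clip] in place."
    with torch.no_grad():
        for p in model.parameters():
            p.clamp_(-clip, clip)

critic = nn.Linear(4, 1)
clip_weights(critic, clip=0.01)
max_abs = max(p.abs().max().item() for p in critic.parameters())
```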
from fastai.callback.all import *
generator = basic_generator(64, n_channels=3, n_extra_layers=1)
critic = basic_critic(64, n_channels=3, n_extra_layers=1, act_cls=partial(nn.LeakyReLU, negative_slope=0.2))
learn = GANLearner.wgan(dls, generator, critic, opt_func = RMSProp)
learn.recorder.train_metrics=True
learn.recorder.valid_metrics=False
learn.fit(1, 2e-4, wd=0.)
| epoch | train_loss | gen_loss | crit_loss | time |
|---|---|---|---|---|
| 0 | -0.815071 | 0.646809 | -1.140522 | 00:38 |
learn.show_results(max_n=9, ds_idx=0)