critic = basic_critic(64, 3)
generator = basic_generator(64, 3)
tst = GANModule(critic=critic, generator=generator)
real = torch.randn(2, 3, 64, 64)
real_p = tst(real)
test_eq(real_p.shape, [2,1])
tst.switch() #tst is now in generator mode
noise = torch.randn(2, 100)
fake = tst(noise)
test_eq(fake.shape, real.shape)
tst.switch() #tst is back in critic mode
fake_p = tst(fake)
test_eq(fake_p.shape, [2,1])
GAN
GAN stands for Generative Adversarial Nets, which were invented by Ian Goodfellow. The idea is to train two models at the same time: a generator and a critic. The generator tries to make new images similar to the ones in a dataset, and the critic tries to tell real images apart from the ones the generator produces. The generator returns images; the critic returns a single number (usually a probability: 0. for fake images and 1. for real ones).
We train them against each other in the sense that at each step (more or less), we:
- Freeze the generator and train the critic for one step by:
  - getting one batch of true images (let’s call that real)
  - generating one batch of fake images (let’s call that fake)
  - having the critic evaluate each batch and computing a loss function from that; the important part is that it rewards the detection of real images and penalizes the fake ones
  - updating the weights of the critic with the gradients of this loss
- Freeze the critic and train the generator for one step by:
  - generating one batch of fake images
  - evaluating the critic on it
  - returning a loss that rewards the critic thinking those are real images
  - updating the weights of the generator with the gradients of this loss
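The alternating scheme described above can be sketched in plain PyTorch. This is only an illustration under stated assumptions, not how the library implements it: the linear `generator` and `critic`, the 8-dimensional noise, the 16-dimensional "images", the SGD optimizers and the binary cross-entropy loss are all toy choices, and `set_requires_grad`, `critic_step` and `generator_step` are hypothetical helper names.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy stand-ins (assumptions): 8-dim noise, 16-dim "images", linear models.
generator = nn.Linear(8, 16)
critic = nn.Linear(16, 1)
opt_gen = torch.optim.SGD(generator.parameters(), lr=0.1)
opt_crit = torch.optim.SGD(critic.parameters(), lr=0.1)
bce = nn.BCEWithLogitsLoss()

def set_requires_grad(model, flag):
    # Freeze (flag=False) or unfreeze (flag=True) every parameter of `model`.
    for p in model.parameters():
        p.requires_grad_(flag)

def critic_step(real):
    # Freeze the generator, train the critic on one real and one fake batch.
    set_requires_grad(generator, False)
    set_requires_grad(critic, True)
    n = real.shape[0]
    fake = generator(torch.randn(n, 8))
    # Reward calling real images real (1) and fake images fake (0).
    loss = bce(critic(real), torch.ones(n, 1)) + bce(critic(fake), torch.zeros(n, 1))
    opt_crit.zero_grad()
    loss.backward()
    opt_crit.step()
    return loss.item()

def generator_step(batch_size):
    # Freeze the critic, train the generator to make the critic answer "real".
    set_requires_grad(critic, False)
    set_requires_grad(generator, True)
    fake = generator(torch.randn(batch_size, 8))
    loss = bce(critic(fake), torch.ones(batch_size, 1))
    opt_gen.zero_grad()
    loss.backward()
    opt_gen.step()
    return loss.item()

real = torch.randn(4, 16)
crit_loss = critic_step(real)
gen_loss = generator_step(4)
```

Because each step only calls `step()` on one optimizer and freezes the other model, a critic step leaves the generator weights untouched and vice versa.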
The fastai library provides support for training GANs through the GANTrainer, but doesn’t include more than basic models.
Wrapping the modules
GANModule
def GANModule(
generator:nn.Module=None, # The generator PyTorch module
critic:nn.Module=None, # The discriminator PyTorch module
gen_mode:None | bool=False, # Whether the GAN should be set to generator mode
):
Wrapper around a generator and a critic to create a GAN.
This is just a shell to contain the two models. When called, it delegates the input to either the generator or the critic, depending on the value of gen_mode.
GANModule.switch
def switch(
gen_mode:None | bool=None, # Whether the GAN should be set to generator mode
):
Put the module in generator mode if gen_mode is True, in critic mode otherwise.
By default (leaving gen_mode as None), this puts the module in the other mode (critic mode if it was in generator mode and vice versa).
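To make the delegation and the toggle concrete, here is a framework-free sketch. `TinyGANModule` is a hypothetical stand-in written for this example (the real wrapper is a PyTorch module), and the two lambdas are toy callables that only make it visible which model was called.

```python
class TinyGANModule:
    """Hypothetical, framework-free stand-in for a generator/critic wrapper."""

    def __init__(self, generator, critic, gen_mode=False):
        self.generator, self.critic, self.gen_mode = generator, critic, gen_mode

    def __call__(self, *args):
        # Delegate to whichever model the current mode selects.
        model = self.generator if self.gen_mode else self.critic
        return model(*args)

    def switch(self, gen_mode=None):
        # None toggles the mode; True/False sets it explicitly.
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode

# Toy callables (assumptions) so the delegation is easy to see.
gan = TinyGANModule(generator=lambda z: f"image({z})", critic=lambda x: f"score({x})")
in_critic_mode = gan("x")      # starts in critic mode
gan.switch()                   # toggle: now in generator mode
in_generator_mode = gan("z")
gan.switch(gen_mode=False)     # explicit: back to critic mode
```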
basic_critic
def basic_critic(
in_size:int, # Input size for the critic (same as the output size of the generator)
n_channels:int, # Number of channels of the input for the critic
n_features:int=64, # Number of features used in the critic
n_extra_layers:int=0, # Number of extra hidden layers in the critic
norm_type:NormType=<NormType.Batch: 1>, # Type of normalization to use in the critic
ks:int=3, stride:int=1, padding:NoneType=None, bias:NoneType=None, ndim:int=2, bn_1st:bool=True,
act_cls:type=ReLU, transpose:bool=False, init:str='auto', xtra:NoneType=None, bias_std:float=0.01,
dilation:Union=1, groups:int=1, padding_mode:Literal='zeros', device:NoneType=None, dtype:NoneType=None
)->nn.Sequential:
A basic critic for images n_channels x in_size x in_size.
AddChannels
def AddChannels(
n_dim
):
Add n_dim channels at the end of the input.
basic_generator
def basic_generator(
out_size:int, # Output size for the generator (same as the input size for the critic)
n_channels:int, # Number of channels of the output of the generator
in_sz:int=100, # Size of the input noise vector for the generator
n_features:int=64, # Number of features used in the generator
n_extra_layers:int=0, # Number of extra hidden layers in the generator
ks:int=3, stride:int=1, padding:NoneType=None, bias:NoneType=None, ndim:int=2,
norm_type:NormType=<NormType.Batch: 1>, bn_1st:bool=True, act_cls:type=ReLU, transpose:bool=False,
init:str='auto', xtra:NoneType=None, bias_std:float=0.01, dilation:Union=1, groups:int=1,
padding_mode:Literal='zeros', device:NoneType=None, dtype:NoneType=None
)->nn.Sequential:
A basic generator from in_sz to images n_channels x out_size x out_size.
DenseResBlock
def DenseResBlock(
nf:int, # Number of features
norm_type:NormType=<NormType.Batch: 1>, # Normalization type
ks:int=3, stride:int=1, padding:NoneType=None, bias:NoneType=None, ndim:int=2, bn_1st:bool=True,
act_cls:type=ReLU, transpose:bool=False, init:str='auto', xtra:NoneType=None, bias_std:float=0.01,
dilation:Union=1, groups:int=1, padding_mode:Literal='zeros', device:NoneType=None, dtype:NoneType=None
)->SequentialEx:
Resnet block of nf features. conv_kwargs are passed to conv_layer.
gan_critic
def gan_critic(
n_channels:int=3, # Number of channels of the input for the critic
nf:int=128, # Number of features for the critic
n_blocks:int=3, # Number of ResNet blocks within the critic
p:float=0.15, # Amount of dropout in the critic
)->nn.Sequential:
Critic to train a GAN.
GANLoss
def GANLoss(
gen_loss_func:Callable, # Generator loss function
crit_loss_func:Callable, # Critic loss function
gan_model:GANModule, # The GAN model
):
Wrapper around crit_loss_func and gen_loss_func
GANLoss.generator
def generator(
output, # Generator outputs
target, # Real images
):
Evaluate the output with the critic, then use self.gen_loss_func to evaluate how well the critic was fooled by output.
GANLoss.critic
def critic(
real_pred, # Critic predictions for real images
input, # Input noise vector to pass into generator
):
Create some fake_pred with the generator from input and compare them to real_pred in self.crit_loss_func.
If the generator method is called, this loss function expects the output of the generator and some target (a batch of real images). It evaluates whether the generator successfully fooled the critic using gen_loss_func, which has the following signature:
def gen_loss_func(fake_pred, output, target):
so that it can combine the output of the critic on output (the first argument, fake_pred) with output and target (for instance, if you want to mix the GAN loss with other losses).
If the critic method is called, this loss function expects the real_pred given by the critic and some input (the noise fed to the generator). It evaluates the critic using crit_loss_func, which has the following signature:
def crit_loss_func(real_pred, fake_pred):
where real_pred is the output of the critic on a batch of real images and fake_pred is the output of the critic on a batch of images generated from the noise by the generator.
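As an illustration of these two signatures, the sketch below defines one possible pair of loss functions in plain PyTorch. The choices are assumptions made for the example, not the library's defaults: the critic is taken to output logits, the adversarial terms use binary cross-entropy, and the generator loss mixes in an L1 pixel term with a hypothetical `l1_weight` parameter.

```python
import torch
import torch.nn.functional as F

def gen_loss_func(fake_pred, output, target, l1_weight=10.0):
    # Adversarial term: the generator wants the critic to answer "real" (1).
    adv = F.binary_cross_entropy_with_logits(fake_pred, torch.ones_like(fake_pred))
    # Extra term mixed in: pixel-wise distance between generated and real images.
    return adv + l1_weight * F.l1_loss(output, target)

def crit_loss_func(real_pred, fake_pred):
    # The critic should answer 1 on real images and 0 on generated ones.
    real_loss = F.binary_cross_entropy_with_logits(real_pred, torch.ones_like(real_pred))
    fake_loss = F.binary_cross_entropy_with_logits(fake_pred, torch.zeros_like(fake_pred))
    return (real_loss + fake_loss) / 2

torch.manual_seed(0)
output, target = torch.rand(2, 3, 8, 8), torch.rand(2, 3, 8, 8)  # generated / real images
real_pred, fake_pred = torch.randn(2, 1), torch.randn(2, 1)      # critic outputs (logits)
g_loss = gen_loss_func(fake_pred, output, target)
c_loss = crit_loss_func(real_pred, fake_pred)
```

Both functions return a scalar tensor, so either can be backpropagated directly.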
AdaptiveLoss
def AdaptiveLoss(
crit:Callable
):
Expand the target to match the output size before applying crit.
accuracy_thresh_expand
def accuracy_thresh_expand(
y_pred:Tensor, y_true:Tensor, thresh:float=0.5, sigmoid:bool=True
):
Compute thresholded accuracy after expanding y_true to the size of y_pred.
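Both helpers above rely on expanding a per-item target to the shape of the predictions. The sketch below shows that idea in plain PyTorch; `expand_target` and `thresholded_accuracy` are hypothetical names, and the example assumes 2D predictions of shape (batch, n) with one label per batch item.

```python
import torch

def expand_target(output, target):
    # target: (batch,) -> (batch, 1) -> broadcast to the shape of output
    return target[:, None].expand_as(output).float()

def thresholded_accuracy(y_pred, y_true, thresh=0.5, sigmoid=True):
    if sigmoid:
        y_pred = y_pred.sigmoid()
    # A prediction is a hit when "above threshold" agrees with the expanded label.
    hits = (y_pred > thresh) == expand_target(y_pred, y_true).bool()
    return hits.float().mean()

logits = torch.tensor([[2.0, -1.0], [-3.0, -0.5]])  # batch of 2, two predictions each
labels = torch.tensor([1, 0])                       # one label per batch item
expanded = expand_target(logits, labels)            # [[1., 1.], [0., 0.]]
acc = thresholded_accuracy(logits, labels)          # 3 of 4 predictions agree: 0.75
```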
Callbacks for GAN training
set_freeze_model
def set_freeze_model(
m:nn.Module, # Model to freeze/unfreeze
rg:bool, # `Requires grad` argument. `False` to freeze the parameters, `True` to unfreeze them
):
GANTrainer
def GANTrainer(
switch_eval:bool=False, # Whether the model should be set to eval mode when calculating loss
clip:None | float=None, # How much to clip the weights
beta:float=0.98, # Exponentially weighted smoothing of the losses `beta`
gen_first:bool=False, # Whether we start with generator training
show_img:bool=True, # Whether to show example generated images during training
):
Callback to handle GAN Training.
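The `beta` argument controls how strongly the reported losses are smoothed. One common way to do this, assumed here for illustration, is a debiased exponentially weighted average; `smoothed_losses` is a hypothetical helper, not a library function.

```python
def smoothed_losses(losses, beta=0.98):
    """Debiased exponentially weighted average of a sequence of losses."""
    avg, out = 0.0, []
    for step, loss in enumerate(losses, start=1):
        avg = beta * avg + (1 - beta) * loss
        # Divide by (1 - beta**step) to undo the bias toward the initial 0.
        out.append(avg / (1 - beta ** step))
    return out

noisy = [1.0, 3.0, 1.0, 3.0, 1.0, 3.0]
smooth = smoothed_losses(noisy)
```

With `beta` close to 1 the smoothed curve reacts slowly to individual batches, and a constant loss is reproduced exactly thanks to the debiasing term.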
The GANTrainer is useless on its own; you need to complete it with one of the following switchers:
FixedGANSwitcher
def FixedGANSwitcher(
n_crit:int=1, # How many steps of critic training before switching to generator
n_gen:int=1, # How many steps of generator training before switching to critic
):
Switcher to do n_crit iterations of the critic then n_gen iterations of the generator.
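The resulting schedule can be pictured with a small generator function. This is a sketch of the schedule only, not of the callback itself; `fixed_schedule` is a hypothetical name, and it assumes at least one of `n_crit` and `n_gen` is positive.

```python
from itertools import islice

def fixed_schedule(n_crit=1, n_gen=1):
    """Yield 'critic' n_crit times, then 'generator' n_gen times, forever."""
    while True:
        for _ in range(n_crit):
            yield "critic"
        for _ in range(n_gen):
            yield "generator"

# Two critic steps for every generator step.
first_steps = list(islice(fixed_schedule(n_crit=2, n_gen=1), 6))
```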
AdaptiveGANSwitcher
def AdaptiveGANSwitcher(
gen_thresh:None | float=None, # Loss threshold for generator
critic_thresh:None | float=None, # Loss threshold for critic
):
Switcher that goes back to generator/critic when the loss goes below gen_thresh/critic_thresh.
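The switching rule can be sketched as a pure function of the current mode and the latest loss. `next_gen_mode` is a hypothetical helper, and treating a threshold of None as "switch after every step" is an assumption made for this example.

```python
def next_gen_mode(gen_mode, loss, gen_thresh=None, critic_thresh=None):
    # Use the threshold of the model that is currently being trained.
    thresh = gen_thresh if gen_mode else critic_thresh
    # Assumption: with no threshold set, switch after every step.
    should_switch = thresh is None or loss < thresh
    return (not gen_mode) if should_switch else gen_mode

mode = False  # start in critic mode
history = []
for loss in [0.9, 0.6, 0.4, 1.2, 0.8]:
    history.append("generator" if mode else "critic")
    mode = next_gen_mode(mode, loss, gen_thresh=1.0, critic_thresh=0.5)
```

In this run the critic trains until its loss drops below 0.5, then the generator trains until its loss drops below 1.0.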
GANDiscriminativeLR
def GANDiscriminativeLR(
mult_lr:float=5.0
):
Callback that handles multiplying the learning rate by mult_lr for the critic.
GAN data
InvisibleTensor
def InvisibleTensor(
args:VAR_POSITIONAL, kwargs:VAR_KEYWORD
):
TensorBase but show method does nothing
generate_noise
def generate_noise(
fn, # Dummy argument so it works with [`DataBlock`](https://docs.fast.ai/data.block.html#datablock)
size:int=100, # Size of returned noise vector
)->InvisibleTensor:
Generate noise vector.
We use the generate_noise function to generate noise vectors to pass into the generator for image generation.
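A minimal sketch of such a function in plain PyTorch is shown below. `make_noise` is a hypothetical stand-in that returns an ordinary tensor, and the file names are made up for the example; the point is that the first argument is ignored, so the function can be mapped over any list of items.

```python
import torch

def make_noise(fn, size=100):
    # `fn` is ignored; it only lets the function be applied to a list of file items.
    return torch.randn(size)

torch.manual_seed(0)
items = ["bedroom_001.jpg", "bedroom_002.jpg"]  # hypothetical file names
batch = torch.stack([make_noise(fn) for fn in items])  # one noise vector per item
```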
bs = 128
size = 64
dblock = DataBlock(blocks = (TransformBlock, ImageBlock),
                   get_x = generate_noise,
                   get_items = get_image_files,
                   splitter = IndexSplitter([]),
                   item_tfms=Resize(size, method=ResizeMethod.Crop),
                   batch_tfms = Normalize.from_stats(torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5])))
path = untar_data(URLs.LSUN_BEDROOMS)
dls = dblock.dataloaders(path, path=path, bs=bs)
dls.show_batch(max_n=16)
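The normalization statistics in the data block above use a mean and standard deviation of 0.5 for every channel. Assuming pixel values are already scaled to [0, 1], the arithmetic maps them to [-1, 1], as this small worked example shows (`normalize` and `denormalize` are hypothetical helpers operating on single values).

```python
def normalize(x, mean=0.5, std=0.5):
    return (x - mean) / std

def denormalize(y, mean=0.5, std=0.5):
    return y * std + mean

# Black, mid-gray and white pixels after normalization.
darkest, mid_gray, brightest = normalize(0.0), normalize(0.5), normalize(1.0)
```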
GAN Learner
gan_loss_from_func
def gan_loss_from_func(
loss_gen:Callable, # A loss function for the generator. Evaluates generator output images and target real images
loss_crit:Callable, # A loss function for the critic. Evaluates predictions of real and fake images.
weights_gen:None | MutableSequence | tuple=None, # Weights for the generator and critic loss function
):
Define loss functions for a GAN from loss_gen and loss_crit.
GANLearner
def GANLearner(
dls:DataLoaders, # DataLoaders object for GAN data
generator:nn.Module, # Generator model
critic:nn.Module, # Critic model
gen_loss_func:Callable, # Generator loss function
crit_loss_func:Callable, # Critic loss function
switcher:Callback | None=None, # Callback for switching between generator and critic training, defaults to [`FixedGANSwitcher`](https://docs.fast.ai/vision.gan.html#fixedganswitcher)
gen_first:bool=False, # Whether we start with generator training
switch_eval:bool=True, # Whether the model should be set to eval mode when calculating loss
show_img:bool=True, # Whether to show example generated images during training
clip:None | float=None, # How much to clip the weights
cbs:Callback | None | MutableSequence=None, # Additional callbacks
metrics:None | MutableSequence | Callable=None, # Metrics
loss_func:Callable | None=None, # Loss function. Defaults to `dls` loss
opt_func:Optimizer | OptimWrapper=Adam, # Optimization function for training
lr:float | slice=0.001, # Default learning rate
splitter:Callable=trainable_params, # Split model into parameter groups. Defaults to one parameter group
path:str | Path | None=None, # Parent directory to save, load, and export models. Defaults to `dls` `path`
model_dir:str | Path='models', # Subdirectory to save and load models
wd:float | int | None=None, # Default weight decay
wd_bn_bias:bool=False, # Apply weight decay to normalization and bias parameters
train_bn:bool=True, # Train frozen normalization layers
moms:tuple=(0.95, 0.85, 0.95), # Default momentum for schedulers
default_cbs:bool=True, # Include default [`Callback`](https://docs.fast.ai/callback.core.html#callback)s
):
A Learner suitable for GANs.
GANLearner.from_learners
def from_learners(
gen_learn:Learner, # A [`Learner`](https://docs.fast.ai/learner.html#learner) object that contains the generator
crit_learn:Learner, # A [`Learner`](https://docs.fast.ai/learner.html#learner) object that contains the critic
switcher:Callback | None=None, # Callback for switching between generator and critic training, defaults to [`FixedGANSwitcher`](https://docs.fast.ai/vision.gan.html#fixedganswitcher)
weights_gen:None | MutableSequence | tuple=None, # Weights for the generator and critic loss function
gen_first:bool=False, # Whether we start with generator training
switch_eval:bool=True, # Whether the model should be set to eval mode when calculating loss
show_img:bool=True, # Whether to show example generated images during training
clip:None | float=None, # How much to clip the weights
cbs:Callback | None | MutableSequence=None, # Additional callbacks
metrics:None | MutableSequence | Callable=None, # Metrics
loss_func:Callable | None=None, # Loss function. Defaults to `dls` loss
opt_func:Optimizer | OptimWrapper=Adam, # Optimization function for training
lr:float | slice=0.001, # Default learning rate
splitter:Callable=trainable_params, # Split model into parameter groups. Defaults to one parameter group
path:str | Path | None=None, # Parent directory to save, load, and export models. Defaults to `dls` `path`
model_dir:str | Path='models', # Subdirectory to save and load models
wd:float | int | None=None, # Default weight decay
wd_bn_bias:bool=False, # Apply weight decay to normalization and bias parameters
train_bn:bool=True, # Train frozen normalization layers
moms:tuple=(0.95, 0.85, 0.95), # Default momentum for schedulers
default_cbs:bool=True, # Include default [`Callback`](https://docs.fast.ai/callback.core.html#callback)s
):
Create a GAN from gen_learn and crit_learn.
GANLearner.wgan
def wgan(
dls:DataLoaders, # DataLoaders object for GAN data
generator:nn.Module, # Generator model
critic:nn.Module, # Critic model
switcher:Callback | None=None, # Callback for switching between generator and critic training, defaults to `FixedGANSwitcher(n_crit=5, n_gen=1)`
clip:None | float=0.01, # How much to clip the weights
switch_eval:bool=False, # Whether the model should be set to eval mode when calculating loss
gen_first:bool=False, # Whether we start with generator training
show_img:bool=True, # Whether to show example generated images during training
cbs:Callback | None | MutableSequence=None, # Additional callbacks
metrics:None | MutableSequence | Callable=None, # Metrics
loss_func:Callable | None=None, # Loss function. Defaults to `dls` loss
opt_func:Optimizer | OptimWrapper=Adam, # Optimization function for training
lr:float | slice=0.001, # Default learning rate
splitter:Callable=trainable_params, # Split model into parameter groups. Defaults to one parameter group
path:str | Path | None=None, # Parent directory to save, load, and export models. Defaults to `dls` `path`
model_dir:str | Path='models', # Subdirectory to save and load models
wd:float | int | None=None, # Default weight decay
wd_bn_bias:bool=False, # Apply weight decay to normalization and bias parameters
train_bn:bool=True, # Train frozen normalization layers
moms:tuple=(0.95, 0.85, 0.95), # Default momentum for schedulers
default_cbs:bool=True, # Include default [`Callback`](https://docs.fast.ai/callback.core.html#callback)s
):
Create a WGAN from dls, generator and critic.
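For intuition, a WGAN-style setup can be sketched in plain PyTorch as a pair of mean-based losses plus weight clipping on the critic. Everything here is an assumption made for illustration: the function names are hypothetical, the sign convention is one of several in use, and the linear critic is a toy model.

```python
import torch
from torch import nn

def wgan_gen_loss(fake_pred, output, target):
    # The generator minimizes the critic's mean score on generated images.
    return fake_pred.mean()

def wgan_crit_loss(real_pred, fake_pred):
    # The critic minimizes the gap between its mean scores on real and fake batches.
    return real_pred.mean() - fake_pred.mean()

def clip_weights(model, clip=0.01):
    # Keep every critic parameter inside [-clip, clip] after an update.
    with torch.no_grad():
        for p in model.parameters():
            p.clamp_(-clip, clip)

torch.manual_seed(0)
critic = nn.Linear(16, 1)  # toy critic (assumption)
real, fake = torch.randn(4, 16), torch.randn(4, 16)
loss = wgan_crit_loss(critic(real), critic(fake))
loss.backward()
clip_weights(critic)
max_abs_weight = max(p.abs().max().item() for p in critic.parameters())
```

Unlike a cross-entropy loss, this critic loss is unbounded and can be negative, which is why a constraint such as weight clipping is applied to the critic.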
from fastai.callback.all import *
generator = basic_generator(64, n_channels=3, n_extra_layers=1)
critic    = basic_critic   (64, n_channels=3, n_extra_layers=1, act_cls=partial(nn.LeakyReLU, negative_slope=0.2))
learn = GANLearner.wgan(dls, generator, critic, opt_func = RMSProp)
learn.recorder.train_metrics=True
learn.recorder.valid_metrics=False
learn.fit(1, 2e-4, wd=0.)
| epoch | train_loss | gen_loss | crit_loss | time |
|---|---|---|---|---|
| 0 | -0.815071 | 0.646809 | -1.140522 | 00:38 |
learn.show_results(max_n=9, ds_idx=0)