```python
import torch
from fastcore.test import test_eq
from fastai.vision.gan import *

critic = basic_critic(64, 3)
generator = basic_generator(64, 3)
tst = GANModule(critic=critic, generator=generator)
real = torch.randn(2, 3, 64, 64)
real_p = tst(real)
test_eq(real_p.shape, [2,1])

tst.switch() # tst is now in generator mode
noise = torch.randn(2, 100)
fake = tst(noise)
test_eq(fake.shape, real.shape)

tst.switch() # tst is back in critic mode
fake_p = tst(fake)
test_eq(fake_p.shape, [2,1])
```
GAN
GAN stands for Generative Adversarial Nets, which were invented by Ian Goodfellow. The concept is that we train two models at the same time: a generator and a critic. The generator tries to make new images similar to the ones in a dataset, and the critic tries to distinguish real images from the ones the generator produces. The generator returns images, the critic a single number (usually a probability: 0. for fake images and 1. for real ones).
We train them against each other in the sense that at each step (more or less), we:
- Freeze the generator and train the critic for one step by:
  - getting one batch of true images (let's call that `real`)
  - generating one batch of fake images (let's call that `fake`)
  - having the critic evaluate each batch and computing a loss function from that; the important part is that the loss rewards the critic for recognizing real images as real and fake images as fake
  - updating the weights of the critic with the gradients of this loss
- Freeze the critic and train the generator for one step by:
  - generating one batch of fake images
  - evaluating the critic on it
  - returning a loss that rewards the generator when the critic thinks those are real images
  - updating the weights of the generator with the gradients of this loss
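The two objectives in the steps above can be made concrete with the usual binary cross-entropy formulation. This is a plain-Python sketch with hypothetical helper names (`bce`, `critic_loss`, `generator_loss`); it assumes the critic outputs probabilities, as described earlier, and is not fastai's implementation.

```python
import math

def bce(p: float, target: float) -> float:
    """Binary cross-entropy of one probability p against a 0/1 target."""
    eps = 1e-12
    p = min(max(p, eps), 1 - eps)  # avoid log(0)
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))

def critic_loss(real_preds, fake_preds):
    """Low when real images score near 1. and fake images score near 0."""
    losses = [bce(p, 1.0) for p in real_preds] + [bce(p, 0.0) for p in fake_preds]
    return sum(losses) / len(losses)

def generator_loss(fake_preds):
    """Low when the critic scores the generator's fakes near 1. (real)."""
    return sum(bce(p, 1.0) for p in fake_preds) / len(fake_preds)

# A critic that separates real from fake gets a lower loss...
sharp_critic = critic_loss([0.9, 0.95], [0.05, 0.1])
# ...than one that has it backwards.
fooled_critic = critic_loss([0.1, 0.2], [0.9, 0.8])
print(sharp_critic < fooled_critic)  # True

# The generator is rewarded when its fakes are judged real.
print(generator_loss([0.9]) < generator_loss([0.1]))  # True
```

In a real training loop, each loss would only update the weights of the model that is not frozen at that step.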
The fastai library provides support for training GANs through the GANTrainer, but doesn’t include more than basic models.
Wrapping the modules
GANModule
GANModule (generator:nn.Module=None, critic:nn.Module=None, gen_mode:None|bool=False)
Wrapper around a `generator` and a `critic` to create a GAN.
Parameter | Type | Default | Details
---|---|---|---
generator | Module | None | The generator PyTorch module
critic | Module | None | The discriminator PyTorch module
gen_mode | None \| bool | False | Whether the GAN should be set to generator mode
This is just a shell that contains the two models. When called, it delegates the input to either the `generator` or the `critic`, depending on the value of `gen_mode`.
GANModule.switch
GANModule.switch (gen_mode:None|bool=None)
Put the module in generator mode if `gen_mode` is `True`, in critic mode otherwise.
Parameter | Type | Default | Details
---|---|---|---
gen_mode | None \| bool | None | Whether the GAN should be set to generator mode
By default (leaving `gen_mode` as `None`), this puts the module in the other mode (critic mode if it was in generator mode and vice versa).
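The dispatch-and-toggle behaviour of `GANModule` and `switch` can be sketched in plain Python. `TinyGANModule` is a hypothetical class for illustration only; the real `GANModule` is an `nn.Module` that wraps two PyTorch models.

```python
from typing import Any, Callable, Optional

class TinyGANModule:
    """Hypothetical stand-in for GANModule: holds two callables and routes calls."""

    def __init__(self, generator: Callable, critic: Callable, gen_mode: bool = False):
        self.generator, self.critic, self.gen_mode = generator, critic, gen_mode

    def __call__(self, *args: Any) -> Any:
        # Generator mode: noise -> fake image; critic mode: image -> score.
        return self.generator(*args) if self.gen_mode else self.critic(*args)

    def switch(self, gen_mode: Optional[bool] = None) -> None:
        # None toggles the current mode; an explicit bool sets it directly.
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode

gan = TinyGANModule(generator=lambda z: f"fake from {z}", critic=lambda x: 0.5)
print(gan("img"))          # 0.5 (critic mode by default)
gan.switch()               # toggle to generator mode
print(gan("noise"))        # fake from noise
gan.switch(gen_mode=True)  # explicitly stay in generator mode
print(gan.gen_mode)        # True
```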
basic_critic
basic_critic (in_size:int, n_channels:int, n_features:int=64, n_extra_layers:int=0, norm_type:NormType=<NormType.Batch: 1>, ks=3, stride=1, padding=None, bias=None, ndim=2, bn_1st=True, act_cls=<class 'torch.nn.modules.activation.ReLU'>, transpose=False, init='auto', xtra=None, bias_std=0.01, dilation:Union[int,Tuple[int,int]]=1, groups:int=1, padding_mode:str='zeros', device=None, dtype=None)
A basic critic for images of size `n_channels` x `in_size` x `in_size`.
Parameter | Type | Default | Details
---|---|---|---
in_size | int | | Input size for the critic (same as the output size of the generator)
n_channels | int | | Number of channels of the input for the critic
n_features | int | 64 | Number of features used in the critic
n_extra_layers | int | 0 | Number of extra hidden layers in the critic
norm_type | NormType | NormType.Batch | Type of normalization to use in the critic
ks | int | 3 |
stride | int | 1 |
padding | NoneType | None |
bias | NoneType | None |
ndim | int | 2 |
bn_1st | bool | True |
act_cls | type | ReLU |
transpose | bool | False |
init | str | auto |
xtra | NoneType | None |
bias_std | float | 0.01 |
dilation | Union | 1 |
groups | int | 1 |
padding_mode | str | zeros |
device | NoneType | None |
dtype | NoneType | None |
Returns | nn.Sequential | |
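To see how a critic can reduce an image to a single score, here is a plain-Python shape trace. It assumes a DCGAN-style layout (a first stride-2 4x4 convolution, then stride-2 convolutions that halve the size and double the features until the feature map is 4x4, then a final 4x4 convolution down to one channel) and ignores `n_extra_layers`; `conv_out` and `critic_shape_trace` are hypothetical helpers, not fastai functions.

```python
def conv_out(size: int, ks: int, stride: int, padding: int) -> int:
    """Spatial output size of a standard convolution (dilation 1)."""
    return (size + 2 * padding - ks) // stride + 1

def critic_shape_trace(in_size: int, n_channels: int, n_features: int = 64):
    """Hypothetical (channels, size) trace of a DCGAN-style critic."""
    trace = [(n_channels, in_size)]
    size, ftrs = conv_out(in_size, 4, 2, 1), n_features  # first conv halves the size
    trace.append((ftrs, size))
    while size > 4:
        # Each stride-2 block halves the spatial size and doubles the features.
        size, ftrs = conv_out(size, 4, 2, 1), ftrs * 2
        trace.append((ftrs, size))
    trace.append((1, conv_out(size, 4, 1, 0)))  # final 4x4 conv -> one score per image
    return trace

print(critic_shape_trace(64, 3))
# [(3, 64), (64, 32), (128, 16), (256, 8), (512, 4), (1, 1)]
```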
AddChannels
AddChannels (n_dim)
Add `n_dim` channels at the end of the input.
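Adding trailing singleton dimensions changes only the shape, not the data. The sketch below works on shapes in plain Python with a hypothetical `add_channels_shape` helper; with PyTorch, one way to get the same effect would be `x.view(*x.shape, *[1] * n_dim)` (an assumption about a possible implementation, not necessarily fastai's).

```python
def add_channels_shape(shape, n_dim):
    """Append n_dim singleton dimensions to a tensor shape (hypothetical helper)."""
    return tuple(shape) + (1,) * n_dim

# A batch of 64 noise vectors of size 100 becomes a 64 x 100 x 1 x 1 "image",
# which a transposed convolution can then upsample.
print(add_channels_shape((64, 100), 2))  # (64, 100, 1, 1)
```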
basic_generator
basic_generator (out_size:int, n_channels:int, in_sz:int=100, n_features:int=64, n_extra_layers:int=0, ks=3, stride=1, padding=None, bias=None, ndim=2, norm_type=<NormType.Batch: 1>, bn_1st=True, act_cls=<class 'torch.nn.modules.activation.ReLU'>, transpose=False, init='auto', xtra=None, bias_std=0.01, dilation:Union[int,Tuple[int,int]]=1, groups:int=1, padding_mode:str='zeros', device=None, dtype=None)
A basic generator from `in_sz` to images of size `n_channels` x `out_size` x `out_size`.
Parameter | Type | Default | Details
---|---|---|---
out_size | int | | Output size for the generator (same as the input size for the critic)
n_channels | int | | Number of channels of the output of the generator
in_sz | int | 100 | Size of the input noise vector for the generator
n_features | int | 64 | Number of features used in the generator
n_extra_layers | int | 0 | Number of extra hidden layers in the generator
ks | int | 3 |
stride | int | 1 |
padding | NoneType | None |
bias | NoneType | None |
ndim | int | 2 |
norm_type | NormType | NormType.Batch |
bn_1st | bool | True |
act_cls | type | ReLU |
transpose | bool | False |
init | str | auto |
xtra | NoneType | None |
bias_std | float | 0.01 |
dilation | Union | 1 |
groups | int | 1 |
padding_mode | str | zeros |
device | NoneType | None |
dtype | NoneType | None |
Returns | nn.Sequential | |
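The generator runs the critic's pattern in reverse. The trace below assumes a DCGAN-style layout: the noise is reshaped to `in_sz` x 1 x 1, a 4x4 transposed convolution produces a 4x4 map, stride-2 transposed convolutions then double the size and halve the features, and a last stride-2 layer outputs `n_channels`. It ignores `n_extra_layers`; `tconv_out` and `generator_shape_trace` are hypothetical helpers.

```python
def tconv_out(size: int, ks: int, stride: int, padding: int) -> int:
    """Output size of a transposed convolution (no output_padding, dilation 1)."""
    return (size - 1) * stride - 2 * padding + ks

def generator_shape_trace(out_size: int, n_channels: int, in_sz: int = 100, n_features: int = 64):
    """Hypothetical (channels, size) trace of a DCGAN-style generator."""
    # Choose the starting feature count so it halves down to n_features at the end.
    size, ftrs = 4, n_features // 2
    while size < out_size:
        size, ftrs = size * 2, ftrs * 2
    trace = [(in_sz, 1)]                # noise viewed as in_sz x 1 x 1
    size = tconv_out(1, 4, 1, 0)        # 1x1 -> 4x4
    trace.append((ftrs, size))
    while size < out_size // 2:
        # Double the spatial size, halve the features.
        size, ftrs = tconv_out(size, 4, 2, 1), ftrs // 2
        trace.append((ftrs, size))
    trace.append((n_channels, tconv_out(size, 4, 2, 1)))  # final layer -> image
    return trace

print(generator_shape_trace(64, 3))
# [(100, 1), (512, 4), (256, 8), (128, 16), (64, 32), (3, 64)]
```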
DenseResBlock
DenseResBlock (nf:int, norm_type:NormType=<NormType.Batch: 1>, ks=3, stride=1, padding=None, bias=None, ndim=2, bn_1st=True, act_cls=<class 'torch.nn.modules.activation.ReLU'>, transpose=False, init='auto', xtra=None, bias_std=0.01, dilation:Union[int,Tuple[int,int]]=1, groups:int=1, padding_mode:str='zeros', device=None, dtype=None)
Resnet block of `nf` features. The remaining keyword arguments (`conv_kwargs`) are passed to the convolution layers.
Parameter | Type | Default | Details
---|---|---|---
nf | int | | Number of features
norm_type | NormType | NormType.Batch | Normalization type
ks | int | 3 |
stride | int | 1 |
padding | NoneType | None |
bias | NoneType | None |
ndim | int | 2 |
bn_1st | bool | True |
act_cls | type | ReLU |
transpose | bool | False |
init | str | auto |
xtra | NoneType | None |
bias_std | float | 0.01 |
dilation | Union | 1 |
groups | int | 1 |
padding_mode | str | zeros |
device | NoneType | None |
dtype | NoneType | None |
Returns | SequentialEx | |
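Unlike a plain ResNet block, which adds the input back, a dense block concatenates it. Assuming `DenseResBlock` applies two `nf`-to-`nf` convolutions and then merges by concatenating along the channel axis (so the output has `2 * nf` channels), the data flow looks like this plain-Python sketch, where a list stands in for the channel axis and `fake_conv` is a hypothetical stand-in for a convolution layer.

```python
from typing import Callable, List

def dense_res_block(x: List[float], conv: Callable[[List[float]], List[float]]) -> List[float]:
    """Two nf -> nf convs, then concatenate the result with the input (dense merge)."""
    out = conv(conv(x))
    return out + x  # list concatenation: nf + nf = 2 * nf "channels"

def fake_conv(channels: List[float]) -> List[float]:
    # Hypothetical conv stand-in: keeps the channel count, transforms each value.
    return [2 * c for c in channels]

x = [1.0, 2.0, 3.0]  # nf = 3 channels
y = dense_res_block(x, fake_conv)
print(y)       # [4.0, 8.0, 12.0, 1.0, 2.0, 3.0]
print(len(y))  # 6 == 2 * nf
```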
gan_critic
gan_critic (n_channels:int=3, nf:int=128, n_blocks:int=3, p:float=0.15)
Critic to train a `GAN`.
Parameter | Type | Default | Details
---|---|---|---
n_channels | int | 3 | Number of channels of the input for the critic
nf | int | 128 | Number of features for the critic
n_blocks | int | 3 | Number of ResNet blocks within the critic
p | float | 0.15 | Amount of dropout in the critic
Returns | Sequential | |
GANLoss
GANLoss (gen_loss_func:callable, crit_loss_func:callable, gan_model:GANModule)
Wrapper around `crit_loss_func` and `gen_loss_func`.
Parameter | Type | Details
---|---|---
gen_loss_func | callable | Generator loss function
crit_loss_func | callable | Critic loss function
gan_model | GANModule | The GAN model
GANLoss.generator
GANLoss.generator (output, target)
Evaluate the `output` with the critic, then use `self.gen_loss_func` to evaluate how well the critic was fooled by `output`.
Parameter | Details
---|---
output | Generator outputs
target | Real images
GANLoss.critic
GANLoss.critic (real_pred, input)
Create some `fake_pred` with the generator from `input` and compare them to `real_pred` in `self.crit_loss_func`.
Parameter | Details
---|---
real_pred | Critic predictions for real images
input | Input noise vector to pass into generator
If the `generator` method is called, this loss function expects the `output` of the generator and some `target` (a batch of real images). It evaluates whether the generator successfully fooled the critic using `gen_loss_func`. This loss function has the following signature

```python
def gen_loss_func(fake_pred, output, target): ...
```

so that it can combine the critic's predictions on `output` (the first argument, `fake_pred`) with `output` and `target` (if you want to mix the GAN loss with other losses, for instance).
If the `critic` method is called, this loss function expects the `real_pred` given by the critic and some `input` (the noise fed to the generator). It evaluates the critic using `crit_loss_func`. This loss function has the following signature

```python
def crit_loss_func(real_pred, fake_pred): ...
```

where `real_pred` is the output of the critic on a batch of real images and `fake_pred` is the output of the critic on a batch of fake images that the generator produced from the noise.
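As an illustration of the two signatures, here is a pair of Wasserstein-style loss functions written in plain Python over lists of critic scores. They are hypothetical examples and we do not claim they match fastai's internals. With this sign convention, minimizing the critic loss pushes real scores below fake scores, and minimizing the generator loss pushes the fakes' scores down toward the "real" side. The third function shows how the extra `output` and `target` arguments allow mixing in a pixel loss (the weight of 10 is illustrative).

```python
from statistics import mean

def wgan_gen_loss(fake_pred, output, target):
    """Generator loss: mean critic score on the fakes; output/target are unused here."""
    return mean(fake_pred)

def wgan_crit_loss(real_pred, fake_pred):
    """Critic loss: mean score on real images minus mean score on fake images."""
    return mean(real_pred) - mean(fake_pred)

def gen_loss_with_pixels(fake_pred, output, target, pixel_weight=10.0):
    """Same adversarial term, plus an L1 pixel loss between output and target."""
    l1 = mean(abs(o - t) for o, t in zip(output, target))
    return wgan_gen_loss(fake_pred, output, target) + pixel_weight * l1

print(wgan_crit_loss([1.0, 0.5], [0.25, 0.75]))                    # 0.25
print(gen_loss_with_pixels([0.25, 0.75], [0.5, 1.0], [0.0, 1.0]))  # 3.0
```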
AdaptiveLoss
AdaptiveLoss (crit:callable)
Expand the `target` to match the `output` size before applying `crit`.
accuracy_thresh_expand
accuracy_thresh_expand (y_pred:torch.Tensor, y_true:torch.Tensor, thresh:float=0.5, sigmoid:bool=True)
Compute thresholded accuracy after expanding `y_true` to the size of `y_pred`.
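Both `AdaptiveLoss` and `accuracy_thresh_expand` rely on the same idea: a single target per item (for example 1. for "real") is repeated to match a critic that emits several predictions per item. The plain-Python sketch below uses nested lists instead of tensors; `expand_target`, `adaptive_loss` and `accuracy_thresh_expand_sketch` are hypothetical stand-ins, and `mse` is just an example `crit`.

```python
import math

def expand_target(y_true, n_cols):
    """Repeat each item's single target across its n_cols predictions."""
    return [[float(t)] * n_cols for t in y_true]

def adaptive_loss(crit, output, target):
    """Hypothetical AdaptiveLoss: expand target to output's shape, then apply crit."""
    return crit(output, expand_target(target, len(output[0])))

def mse(output, target):
    # Mean squared error over every prediction in the batch.
    diffs = [(o - t) ** 2 for row_o, row_t in zip(output, target) for o, t in zip(row_o, row_t)]
    return sum(diffs) / len(diffs)

def accuracy_thresh_expand_sketch(y_pred, y_true, thresh=0.5, sigmoid=True):
    """Fraction of predictions on the correct side of thresh, after expanding y_true."""
    expanded = expand_target(y_true, len(y_pred[0]))
    hits = total = 0
    for row_p, row_t in zip(y_pred, expanded):
        for p, t in zip(row_p, row_t):
            if sigmoid:
                p = 1 / (1 + math.exp(-p))  # raw score -> probability
            hits += (p > thresh) == bool(t)
            total += 1
    return hits / total

# Two items, each with two predictions; one target per item.
preds = [[2.0, -1.0], [-3.0, -0.5]]
targets = [1, 0]
print(accuracy_thresh_expand_sketch(preds, targets))          # 0.75
print(adaptive_loss(mse, [[0.5, 1.0], [0.0, 0.5]], targets))  # 0.125
```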
Callbacks for GAN training
set_freeze_model
set_freeze_model (m:torch.nn.modules.module.Module, rg:bool)
Parameter | Type | Details
---|---|---
m | Module | Model to freeze/unfreeze
rg | bool | Requires grad argument. `True` to unfreeze (make the parameters trainable), `False` to freeze
GANTrainer
GANTrainer (switch_eval:bool=False, clip:None|float=None, beta:float=0.98, gen_first:bool=False, show_img:bool=True)
Callback to handle GAN Training.
Parameter | Type | Default | Details
---|---|---|---
switch_eval | bool | False | Whether the model should be set to eval mode when calculating loss
clip | None \| float | None | How much to clip the weights
beta | float | 0.98 | Beta for the exponentially weighted smoothing of the losses
gen_first | bool | False | Whether we start with generator training
show_img | bool | True | Whether to show example generated images during training
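The `beta` parameter controls how strongly the reported generator and critic losses are smoothed. A common way to do this, sketched here under the assumption that `GANTrainer` smooths in a similar way, is an exponentially weighted moving average with bias correction; `smooth_losses` is a hypothetical helper, not part of fastai.

```python
def smooth_losses(losses, beta=0.98):
    """Exponentially weighted moving average with bias correction."""
    avg, smoothed = 0.0, []
    for i, loss in enumerate(losses, start=1):
        avg = beta * avg + (1 - beta) * loss    # running EMA, biased toward 0 at the start
        smoothed.append(avg / (1 - beta ** i))  # correct the start-up bias
    return smoothed

noisy = [1.0, 3.0, 1.0, 3.0, 1.0, 3.0]
print([round(v, 3) for v in smooth_losses(noisy)])
```

Without the bias correction, the first few smoothed values would be close to 0 instead of close to the actual losses.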
The `GANTrainer` is useless on its own; you need to complete it with one of the following switchers.
FixedGANSwitcher
FixedGANSwitcher (n_crit:int=1, n_gen:int=1)
Switcher to do `n_crit` iterations of the critic, then `n_gen` iterations of the generator.
Parameter | Type | Default | Details
---|---|---|---
n_crit | int | 1 | How many steps of critic training before switching to generator
n_gen | int | 1 | How many steps of generator training before switching to critic
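The resulting schedule can be pictured as an endless cycle. This plain-Python sketch assumes training starts with the critic (`gen_first=False`); `fixed_schedule` is a hypothetical helper that only lists which model trains at each step.

```python
from itertools import cycle, islice

def fixed_schedule(n_crit=1, n_gen=1):
    """Endless sequence: n_crit critic steps, then n_gen generator steps, repeated."""
    return cycle(["critic"] * n_crit + ["generator"] * n_gen)

print(list(islice(fixed_schedule(n_crit=2, n_gen=1), 7)))
# ['critic', 'critic', 'generator', 'critic', 'critic', 'generator', 'critic']
```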
AdaptiveGANSwitcher
AdaptiveGANSwitcher (gen_thresh:None|float=None, critic_thresh:None|float=None)
Switcher that keeps training the current model until its loss goes below its threshold (`gen_thresh` for the generator, `critic_thresh` for the critic), then switches to the other model.
Parameter | Type | Default | Details
---|---|---|---
gen_thresh | None \| float | None | Loss threshold for generator
critic_thresh | None \| float | None | Loss threshold for critic
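Our reading of the switching rule, as a plain-Python sketch: after each batch, compare the current model's loss with that model's threshold and switch once it is below; with no threshold set, switch after every batch. `next_mode` is a hypothetical helper, and the "no threshold" behaviour is an assumption.

```python
from typing import Optional

def next_mode(gen_mode: bool, loss: float,
              gen_thresh: Optional[float] = None,
              critic_thresh: Optional[float] = None) -> bool:
    """Return the mode for the next batch (True = generator) after a batch with this loss."""
    thresh = gen_thresh if gen_mode else critic_thresh
    # Switch once the current model's loss is below its threshold;
    # with no threshold, switch after every batch.
    if thresh is None or loss < thresh:
        return not gen_mode
    return gen_mode

print(next_mode(True, 0.5, gen_thresh=1.0))         # False: generator did well, train critic
print(next_mode(False, 0.5, critic_thresh=0.2))     # False: critic keeps training
```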
GANDiscriminativeLR
GANDiscriminativeLR (mult_lr=5.0)
`Callback` that handles multiplying the learning rate by `mult_lr` for the critic.
GAN data
InvisibleTensor
InvisibleTensor (x, **kwargs)
`TensorBase` but the `show` method does nothing.
generate_noise
generate_noise (fn, size=100)
Generate noise vector.
Parameter | Type | Default | Details
---|---|---|---
fn | | | Dummy argument so it works with DataBlock
size | int | 100 | Size of returned noise vector
Returns | InvisibleTensor | |

We use the `generate_noise` function to generate noise vectors to pass into the generator for image generation.
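In a `DataBlock`, `get_x` receives each item returned by `get_items` (here, an image file), so `generate_noise` takes `fn` only to fit that interface and ignores it. This plain-Python sketch mirrors that behaviour; fastai's version returns an `InvisibleTensor` drawn with PyTorch, while the hypothetical `generate_noise_sketch` returns a list, and the file name below is made up.

```python
import random

def generate_noise_sketch(fn, size=100):
    """Ignore fn (the item from get_items) and return `size` standard-normal samples."""
    return [random.gauss(0.0, 1.0) for _ in range(size)]

noise = generate_noise_sketch("bedroom_0001.jpg")  # the file name is not used
print(len(noise))  # 100
```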
```python
from fastai.vision.all import *
from fastai.vision.gan import *

bs = 128
size = 64
dblock = DataBlock(blocks = (TransformBlock, ImageBlock),
                   get_x = generate_noise,        # noise is the input
                   get_items = get_image_files,   # real images are the target
                   splitter = IndexSplitter([]),  # no validation set
                   item_tfms = Resize(size, method=ResizeMethod.Crop),
                   batch_tfms = Normalize.from_stats(torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5])))
path = untar_data(URLs.LSUN_BEDROOMS)
dls = dblock.dataloaders(path, path=path, bs=bs)
dls.show_batch(max_n=16)
```
GAN Learner
gan_loss_from_func
gan_loss_from_func (loss_gen:callable, loss_crit:callable, weights_gen:None|collections.abc.MutableSequence|tuple=None)
Define loss functions for a GAN from `loss_gen` and `loss_crit`.
Parameter | Type | Default | Details
---|---|---|---
loss_gen | callable | | A loss function for the generator. Evaluates generator output images and target real images
loss_crit | callable | | A loss function for the critic. Evaluates predictions of real and fake images.
weights_gen | None \| collections.abc.MutableSequence \| tuple | None | Weights for the generator and critic loss function
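The way two plain loss functions can be turned into the `gen_loss_func`/`crit_loss_func` pair is sketched below in plain Python. `gan_loss_from_func_sketch` is hypothetical; we assume that `loss_crit` is applied with a target of 1. ("real") to the fakes for the generator and with 1./0. targets to real/fake predictions for the critic, and that the first weight scales the adversarial term and the second scales `loss_gen`, defaulting to (1., 1.). The weights (1., 100.) are only illustrative.

```python
import math
from statistics import mean

def bce_to_label(preds, label):
    """Mean binary cross-entropy of probabilities against a constant 0/1 label."""
    eps = 1e-12
    probs = preds if label == 1 else [1 - p for p in preds]
    return mean(-math.log(max(p, eps)) for p in probs)

def l1_loss(output, target):
    """Mean absolute pixel difference between generated and real images."""
    return mean(abs(o - t) for o, t in zip(output, target))

def gan_loss_from_func_sketch(loss_gen, loss_crit, weights_gen=None):
    """Hypothetical mirror of gan_loss_from_func returning (gen_loss_func, crit_loss_func)."""
    # Assumed order: first weight scales the adversarial term, second scales loss_gen.
    w_adv, w_gen = weights_gen if weights_gen is not None else (1.0, 1.0)

    def gen_loss_func(fake_pred, output, target):
        # Reward fooling the critic (fakes labelled 1.) plus staying close to the target.
        return w_adv * loss_crit(fake_pred, 1) + w_gen * loss_gen(output, target)

    def crit_loss_func(real_pred, fake_pred):
        # Real images labelled 1., fake images labelled 0., averaged.
        return (loss_crit(real_pred, 1) + loss_crit(fake_pred, 0)) / 2

    return gen_loss_func, crit_loss_func

gen_loss_func, crit_loss_func = gan_loss_from_func_sketch(l1_loss, bce_to_label, weights_gen=(1.0, 100.0))
print(crit_loss_func([0.5], [0.5]))                  # log(2), about 0.693
print(gen_loss_func([0.5], [0.5, 1.0], [0.0, 1.0]))  # log(2) + 100 * 0.25
```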
GANLearner
GANLearner (dls:DataLoaders, generator:nn.Module, critic:nn.Module, gen_loss_func:callable, crit_loss_func:callable, switcher:Callback|None=None, gen_first:bool=False, switch_eval:bool=True, show_img:bool=True, clip:None|float=None, cbs:Callback|None|MutableSequence=None, metrics:None|MutableSequence|callable=None, loss_func:callable|None=None, opt_func:Optimizer|OptimWrapper=<function Adam>, lr:float|slice=0.001, splitter:callable=<function trainable_params>, path:str|Path|None=None, model_dir:str|Path='models', wd:float|int|None=None, wd_bn_bias:bool=False, train_bn:bool=True, moms:tuple=(0.95, 0.85, 0.95), default_cbs:bool=True)
A `Learner` suitable for GANs.
Parameter | Type | Default | Details
---|---|---|---
dls | DataLoaders | | DataLoaders object for GAN data
generator | nn.Module | | Generator model
critic | nn.Module | | Critic model
gen_loss_func | callable | | Generator loss function
crit_loss_func | callable | | Critic loss function
switcher | Callback \| None | None | Callback for switching between generator and critic training, defaults to FixedGANSwitcher
gen_first | bool | False | Whether we start with generator training
switch_eval | bool | True | Whether the model should be set to eval mode when calculating loss
show_img | bool | True | Whether to show example generated images during training
clip | None \| float | None | How much to clip the weights
cbs | Callback \| None \| MutableSequence | None | Additional callbacks
metrics | None \| MutableSequence \| callable | None | Metrics
loss_func | callable \| None | None |
opt_func | Optimizer \| OptimWrapper | Adam |
lr | float \| slice | 0.001 |
splitter | callable | trainable_params |
path | str \| Path \| None | None |
model_dir | str \| Path | models |
wd | float \| int \| None | None |
wd_bn_bias | bool | False |
train_bn | bool | True |
moms | tuple | (0.95, 0.85, 0.95) |
default_cbs | bool | True |
GANLearner.from_learners
GANLearner.from_learners (gen_learn:Learner, crit_learn:Learner, switcher:Callback|None=None, weights_gen:None|MutableSequence|tuple=None, gen_first:bool=False, switch_eval:bool=True, show_img:bool=True, clip:None|float=None, cbs:Callback|None|MutableSequence=None, metrics:None|MutableSequence|callable=None, loss_func:callable|None=None, opt_func:Optimizer|OptimWrapper=<function Adam>, lr:float|slice=0.001, splitter:callable=<function trainable_params>, path:str|Path|None=None, model_dir:str|Path='models', wd:float|int|None=None, wd_bn_bias:bool=False, train_bn:bool=True, moms:tuple=(0.95, 0.85, 0.95), default_cbs:bool=True)
Create a GAN from `gen_learn` and `crit_learn`.
GANLearner.wgan
GANLearner.wgan (dls:DataLoaders, generator:nn.Module, critic:nn.Module, switcher:Callback|None=None, clip:None|float=0.01, switch_eval:bool=False, gen_first:bool=False, show_img:bool=True, cbs:Callback|None|MutableSequence=None, metrics:None|MutableSequence|callable=None, loss_func:callable|None=None, opt_func:Optimizer|OptimWrapper=<function Adam>, lr:float|slice=0.001, splitter:callable=<function trainable_params>, path:str|Path|None=None, model_dir:str|Path='models', wd:float|int|None=None, wd_bn_bias:bool=False, train_bn:bool=True, moms:tuple=(0.95, 0.85, 0.95), default_cbs:bool=True)
Create a WGAN from `dls`, `generator` and `critic`.
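`GANLearner.wgan` defaults to `clip=0.01`. In a Wasserstein GAN, the critic's weights are clipped to a small range after each update to keep the critic roughly Lipschitz. The plain-Python sketch below shows the clamping on a flat list of hypothetical weights; with PyTorch one would typically call `p.data.clamp_(-clip, clip)` on each critic parameter, and we assume the trainer does something equivalent when `clip` is set.

```python
def clip_weights(weights, clip=0.01):
    """Clamp every weight into [-clip, clip], as in the original WGAN recipe."""
    return [max(-clip, min(clip, w)) for w in weights]

print(clip_weights([0.5, -0.003, -2.0, 0.01]))  # [0.01, -0.003, -0.01, 0.01]
```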
```python
from fastai.vision.all import *
from fastai.vision.gan import *
from fastai.callback.all import *

# `dls` comes from the DataBlock example above
generator = basic_generator(64, n_channels=3, n_extra_layers=1)
critic    = basic_critic   (64, n_channels=3, n_extra_layers=1, act_cls=partial(nn.LeakyReLU, negative_slope=0.2))
learn = GANLearner.wgan(dls, generator, critic, opt_func = RMSProp)
learn.recorder.train_metrics=True
learn.recorder.valid_metrics=False
learn.fit(1, 2e-4, wd=0.)
```
epoch | train_loss | gen_loss | crit_loss | time |
---|---|---|---|---|
0 | -0.815071 | 0.646809 | -1.140522 | 00:38 |
```python
learn.show_results(max_n=9, ds_idx=0)
```