Data core
The classes here provide functionality for applying a list of transforms to a set of items (TfmdLists, Datasets) or a DataLoader (TfmdDL), as well as the base class used to gather the data for model training: DataLoaders.
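As a quick orientation, here is a minimal sketch of how these pieces fit together (assuming fastai is installed; the items and splits below are placeholders, not part of this page's tests):

from fastai.data.core import Datasets

items = list(range(10))
# Datasets applies the tfms (here trivial) and holds train/valid subsets
dsets = Datasets(items, tfms=[None], splits=[list(range(8)), [8,9]])
dls = dsets.dataloaders(bs=4)   # a DataLoaders wrapping a train and a valid TfmdDL
xb, = dls.train.one_batch()     # items are 1-tuples since tfms has a single list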
show_batch is a type-dispatched function that is responsible for showing decoded samples. x and y are the input and the target in the batch to be shown, and are passed along to dispatch on their types. There is a different implementation of show_batch if x is a TensorImage or a TensorText, for instance (see vision.core or text.data for more details). ctxs can be passed, but the function is responsible for creating them if necessary. kwargs depend on the specific implementation.
show_results is a type-dispatched function that is responsible for showing decoded samples and their corresponding outs. As in show_batch, x and y are the input and the target in the batch to be shown, and are passed along to dispatch on their types. ctxs can be passed, but the function is responsible for creating them if necessary. kwargs depend on the specific implementation.
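For example, here is a minimal sketch of how an implementation could be registered for a custom type via fastcore's typedispatch (MyTensor and the print-based display are hypothetical, not part of fastai):

from fastcore.dispatch import typedispatch
from fastai.torch_core import TensorBase

class MyTensor(TensorBase): pass   # hypothetical tensor subclass to dispatch on

@typedispatch
def show_batch(x:MyTensor, y, samples, ctxs=None, max_n=9, **kwargs):
    # Hypothetical display: print each decoded sample instead of plotting it
    for s in samples[:max_n]: print(s)
    return ctxs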
TfmdDL
TfmdDL (dataset, bs:int=64, shuffle:bool=False, num_workers:int=None, verbose:bool=False, do_setup:bool=True, pin_memory=False, timeout=0, batch_size=None, drop_last=False, indexed=None, n=None, device=None, persistent_workers=False, pin_memory_device='', wif=None, before_iter=None, after_item=None, before_batch=None, after_batch=None, after_iter=None, create_batches=None, create_item=None, create_batch=None, retain=None, get_idxs=None, sample=None, shuffle_fn=None, do_batch=None)
Transformed DataLoader
| | Type | Default | Details |
|---|---|---|---|
| dataset | | | Map- or iterable-style dataset from which to load the data |
| bs | int | 64 | Size of batch |
| shuffle | bool | False | Whether to shuffle data |
| num_workers | int | None | Number of CPU cores to use in parallel (default: All available up to 16) |
| verbose | bool | False | Whether to print verbose logs |
| do_setup | bool | True | Whether to run setup() for batch transform(s) |
| pin_memory | bool | False | |
| timeout | int | 0 | |
| batch_size | NoneType | None | |
| drop_last | bool | False | |
| indexed | NoneType | None | |
| n | NoneType | None | |
| device | NoneType | None | |
| persistent_workers | bool | False | |
| pin_memory_device | str | '' | |
| wif | NoneType | None | |
| before_iter | NoneType | None | |
| after_item | NoneType | None | |
| before_batch | NoneType | None | |
| after_batch | NoneType | None | |
| after_iter | NoneType | None | |
| create_batches | NoneType | None | |
| create_item | NoneType | None | |
| create_batch | NoneType | None | |
| retain | NoneType | None | |
| get_idxs | NoneType | None | |
| sample | NoneType | None | |
| shuffle_fn | NoneType | None | |
| do_batch | NoneType | None | |
A TfmdDL is a DataLoader that creates a Pipeline from a list of Transforms for the callbacks after_item, before_batch and after_batch. As a result, it can decode or show a processed batch.
class _Category(int, ShowTitle): pass
#Test retain type
class NegTfm(Transform):
    def encodes(self, x): return torch.neg(x)
    def decodes(self, x): return torch.neg(x)

tdl = TfmdDL([(TensorImage([1]),)] * 4, after_batch=NegTfm(), bs=4, num_workers=4)
b = tdl.one_batch()
test_eq(type(b[0]), TensorImage)
b = (tensor([1.,1.,1.,1.]),)
test_eq(type(tdl.decode_batch(b)[0][0]), TensorImage)
class A(Transform):
    def encodes(self, x): return x
    def decodes(self, x): return TitledInt(x)

@Transform
def f(x)->None: return fastuple((x,x))

start = torch.arange(50)
test_eq_type(f(2), fastuple((2,2)))

a = A()
tdl = TfmdDL(start, after_item=lambda x: (a(x), f(x)), bs=4)
x,y = tdl.one_batch()
test_eq(type(y), fastuple)

s = tdl.decode_batch((x,y))
test_eq(type(s[0][1]), fastuple)
tdl = TfmdDL(torch.arange(0,50), after_item=A(), after_batch=NegTfm(), bs=4)
test_eq(tdl.dataset[0], start[0])
test_eq(len(tdl), (50-1)//4+1)
test_eq(tdl.bs, 4)
test_stdout(tdl.show_batch, '0\n1\n2\n3')
test_stdout(partial(tdl.show_batch, unique=True), '0\n0\n0\n0')
class B(Transform):
    parameters = 'a'
    def __init__(self): self.a = torch.tensor(0.)
    def encodes(self, x): x

tdl = TfmdDL([(TensorImage([1]),)] * 4, after_batch=B(), bs=4)
test_eq(tdl.after_batch.fs[0].a.device, torch.device('cpu'))
tdl.to(default_device())
test_eq(tdl.after_batch.fs[0].a.device, default_device())
Methods
DataLoader.one_batch
DataLoader.one_batch ()
Return one batch from DataLoader.
tfm = NegTfm()
tdl = TfmdDL(start, after_batch=tfm, bs=4)

b = tdl.one_batch()
test_eq(tensor([0,-1,-2,-3]), b)
TfmdDL.decode
TfmdDL.decode (b)
Decode b using tfms

| | Details |
|---|---|
| b | Batch to decode |

test_eq(tdl.decode(b), tensor(0,1,2,3))
TfmdDL.decode_batch
TfmdDL.decode_batch (b, max_n:int=9, full:bool=True)
Decode b entirely

| | Type | Default | Details |
|---|---|---|---|
| b | | | Batch to decode |
| max_n | int | 9 | Maximum number of items to decode |
| full | bool | True | Whether to decode all transforms. If False, decode up to the point the item knows how to show itself |

test_eq(tdl.decode_batch(b), [0,1,2,3])
TfmdDL.show_batch
TfmdDL.show_batch (b=None, max_n:int=9, ctxs=None, show:bool=True, unique:bool=False, **kwargs)
Show b (defaults to one_batch), a list of lists of pipeline outputs (i.e. output of a DataLoader)

| | Type | Default | Details |
|---|---|---|---|
| b | NoneType | None | Batch to show |
| max_n | int | 9 | Maximum number of items to show |
| ctxs | NoneType | None | List of ctx objects to show data. Could be a matplotlib axis, DataFrame, etc. |
| show | bool | True | Whether to display data |
| unique | bool | False | Whether to show only one |
| kwargs | | | |
DataLoader.to
DataLoader.to (device)
Put self and its transforms' state on device
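A short usage sketch, reusing the B transform from the example above: the call moves both future batches and the transform's tensor state.

tdl = TfmdDL([(TensorImage([1]),)] * 4, after_batch=B(), bs=4)
tdl.to(default_device())   # batches and B.a now live on the default device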
DataLoaders
DataLoaders (*loaders, path:str|pathlib.Path='.', device=None)
Basic wrapper around several DataLoaders.
dls = DataLoaders(tdl,tdl)
x = dls.train.one_batch()
x2 = first(tdl)
test_eq(x,x2)
x2 = dls.one_batch()
test_eq(x,x2)
Multiple transforms can be added to multiple dataloaders using DataLoaders.add_tfms. You can specify the dataloaders by a list of names, dls.add_tfms(...,'valid',...), or by index, dls.add_tfms(...,1,...); by default, transforms are added to all dataloaders. event is a required argument that determines when the transform will be run; for more information on events, please refer to TfmdDL. tfms is a list of Transforms, and is a required argument.
class _TestTfm(Transform):
    def encodes(self, o): return torch.ones_like(o)
    def decodes(self, o): return o

tdl1,tdl2 = TfmdDL(start, bs=4),TfmdDL(start, bs=4)
dls2 = DataLoaders(tdl1,tdl2)
dls2.add_tfms([_TestTfm()],'after_batch',['valid'])
dls2.add_tfms([_TestTfm()],'after_batch',[1])
dls2.train.after_batch,dls2.valid.after_batch

(Pipeline: , Pipeline: _TestTfm -> _TestTfm)
class _T(Transform):
    def encodes(self, o): return -o

class _T2(Transform):
    def encodes(self, o): return o/2

#test tfms are applied on both train and valid dl
dls_from_ds = DataLoaders.from_dsets([1,], [5,], bs=1, after_item=_T, after_batch=_T2)
b = first(dls_from_ds.train)
test_eq(b, tensor([-.5]))
b = first(dls_from_ds.valid)
test_eq(b, tensor([-2.5]))
Methods
DataLoaders.__getitem__
DataLoaders.__getitem__ (i)
Retrieve DataLoader at i (0 is training, 1 is validation)
x2

tensor([ 0, -1, -2, -3])

x2 = dls[0].one_batch()
test_eq(x,x2)
DataLoaders.train
DataLoaders.train (x)
partial(func, *args, **keywords) - new function with partial application of the given arguments and keywords.
DataLoaders.valid
DataLoaders.valid (x)
partial(func, *args, **keywords) - new function with partial application of the given arguments and keywords.
DataLoaders.train_ds
DataLoaders.train_ds (x)
partial(func, *args, **keywords) - new function with partial application of the given arguments and keywords.
DataLoaders.valid_ds
DataLoaders.valid_ds (x)
partial(func, *args, **keywords) - new function with partial application of the given arguments and keywords.
FilteredBase
FilteredBase (*args, dl_type=None, **kwargs)
Base class for lists with subsets
FilteredBase.dataloaders
FilteredBase.dataloaders (bs:int=64, shuffle_train:bool=None, shuffle:bool=True, val_shuffle:bool=False, n:int=None, path:str|Path='.', dl_type:TfmdDL=None, dl_kwargs:list=None, device:torch.device=None, drop_last:bool=None, val_bs:int=None, num_workers:int=None, verbose:bool=False, do_setup:bool=True, pin_memory=False, timeout=0, batch_size=None, indexed=None, persistent_workers=False, pin_memory_device='', wif=None, before_iter=None, after_item=None, before_batch=None, after_batch=None, after_iter=None, create_batches=None, create_item=None, create_batch=None, retain=None, get_idxs=None, sample=None, shuffle_fn=None, do_batch=None)
| | Type | Default | Details |
|---|---|---|---|
| bs | int | 64 | Batch size |
| shuffle_train | bool | None | (Deprecated, use shuffle) Shuffle training DataLoader |
| shuffle | bool | True | Shuffle training DataLoader |
| val_shuffle | bool | False | Shuffle validation DataLoader |
| n | int | None | Size of Datasets used to create DataLoader |
| path | str \| Path | . | Path to put in DataLoaders |
| dl_type | TfmdDL | None | Type of DataLoader |
| dl_kwargs | list | None | List of kwargs to pass to individual DataLoaders |
| device | torch.device | None | Device to put DataLoaders |
| drop_last | bool | None | Drop last incomplete batch, defaults to shuffle |
| val_bs | int | None | Validation batch size, defaults to bs |
| num_workers | int | None | Number of CPU cores to use in parallel (default: All available up to 16) |
| verbose | bool | False | Whether to print verbose logs |
| do_setup | bool | True | Whether to run setup() for batch transform(s) |
| pin_memory | bool | False | |
| timeout | int | 0 | |
| batch_size | NoneType | None | |
| indexed | NoneType | None | |
| persistent_workers | bool | False | |
| pin_memory_device | str | '' | |
| wif | NoneType | None | |
| before_iter | NoneType | None | |
| after_item | NoneType | None | |
| before_batch | NoneType | None | |
| after_batch | NoneType | None | |
| after_iter | NoneType | None | |
| create_batches | NoneType | None | |
| create_item | NoneType | None | |
| create_batch | NoneType | None | |
| retain | NoneType | None | |
| get_idxs | NoneType | None | |
| sample | NoneType | None | |
| shuffle_fn | NoneType | None | |
| do_batch | NoneType | None | |
| Returns | DataLoaders | | |
TfmdLists
TfmdLists (items=None, *rest, use_list=False, match=None)
A Pipeline of tfms applied to a collection of items
| | Type | Default | Details |
|---|---|---|---|
| items | list | | Items to apply Transforms to |
| use_list | bool | None | Use list in L |
decode_at
decode_at (o, idx)
Decoded item at idx
Exported source
def decode_at(o, idx):
    "Decoded item at `idx`"
    return o.decode(o[idx])
show_at
show_at (o, idx, **kwargs)
Exported source
def show_at(o, idx, **kwargs):
    "Show item at `idx`"
    return o.show(o[idx], **kwargs)
A TfmdLists combines a collection of objects with a Pipeline. tfms can either be a Pipeline or a list of transforms, in which case it will wrap them in a Pipeline. use_list is passed along to L with the items, and split_idx is passed to each transform of the Pipeline. do_setup indicates if the Pipeline.setup method should be called during initialization.
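For instance, a transform whose state is already computed can be reused without triggering another setup pass; a minimal sketch (the items and transform are placeholders):

tfm = Transform(lambda o: -o)                            # stateless tfm, nothing to set up
tl = TfmdLists([1.,2.,3.], tfms=[tfm], do_setup=False)   # skip Pipeline.setup
test_eq(tl[0], -1.)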
class _IntFloatTfm(Transform):
    def encodes(self, o): return TitledInt(o)
    def decodes(self, o): return TitledFloat(o)
int2f_tfm = _IntFloatTfm()

def _neg(o): return -o
neg_tfm = Transform(_neg, _neg)
items = L([1.,2.,3.]); tfms = [neg_tfm, int2f_tfm]
tl = TfmdLists(items, tfms=tfms)
test_eq_type(tl[0], TitledInt(-1))
test_eq_type(tl[1], TitledInt(-2))
test_eq_type(tl.decode(tl[2]), TitledFloat(3.))
test_stdout(lambda: show_at(tl, 2), '-3')
test_eq(tl.types, [float, float, TitledInt])
tl

TfmdLists: [1.0, 2.0, 3.0]
tfms - [_neg:
encodes: (object,object) -> _neg
decodes: (object,object) -> _neg, _IntFloatTfm:
encodes: (object,object) -> encodes
decodes: (object,object) -> decodes
]
# add splits to TfmdLists
splits = [[0,2],[1]]
tl = TfmdLists(items, tfms=tfms, splits=splits)
test_eq(tl.n_subsets, 2)
test_eq(tl.train, tl.subset(0))
test_eq(tl.valid, tl.subset(1))
test_eq(tl.train.items, items[splits[0]])
test_eq(tl.valid.items, items[splits[1]])
test_eq(tl.train.tfms.split_idx, 0)
test_eq(tl.valid.tfms.split_idx, 1)
test_eq(tl.train.new_empty().split_idx, 0)
test_eq(tl.valid.new_empty().split_idx, 1)
test_eq_type(tl.splits, L(splits))
assert not tl.overlapping_splits()
df = pd.DataFrame(dict(a=[1,2,3],b=[2,3,4]))
tl = TfmdLists(df, lambda o: o.a+1, splits=[[0],[1,2]])
test_eq(tl[1,2], [3,4])
tr = tl.subset(0)
test_eq(tr[:], [2])
val = tl.subset(1)
test_eq(val[:], [3,4])
class _B(Transform):
    def __init__(self): self.m = 0
    def encodes(self, o): return o+self.m
    def decodes(self, o): return o-self.m
    def setups(self, items):
        print(items)
        self.m = tensor(items).float().mean().item()

# test for setup, which updates `self.m`
tl = TfmdLists(items, _B())
test_eq(tl.m, 2)

TfmdLists: [1.0, 2.0, 3.0]
tfms - []
Here’s how we can use TfmdLists.setup to implement a simple category list, getting labels from a mock file list:
class _Cat(Transform):
    order = 1
    def encodes(self, o): return int(self.o2i[o])
    def decodes(self, o): return TitledStr(self.vocab[o])
    def setups(self, items): self.vocab,self.o2i = uniqueify(L(items), sort=True, bidir=True)
tcat = _Cat()

def _lbl(o): return TitledStr(o.split('_')[0])
# Check that tfms are sorted by `order` & `_lbl` is called first
fns = ['dog_0.jpg','cat_0.jpg','cat_2.jpg','cat_1.jpg','dog_1.jpg']
tl = TfmdLists(fns, [tcat,_lbl])
exp_voc = ['cat','dog']

test_eq(tcat.vocab, exp_voc)
test_eq(tl.tfms.vocab, exp_voc)
test_eq(tl.vocab, exp_voc)
test_eq(tl, (1,0,0,0,1))
test_eq([tl.decode(o) for o in tl], ('dog','cat','cat','cat','dog'))
#Check only the training set is taken into account for setup
tl = TfmdLists(fns, [tcat,_lbl], splits=[[0,4], [1,2,3]])
test_eq(tcat.vocab, ['dog'])
tfm = NegTfm(split_idx=1)
tds = TfmdLists(start, A())
tdl = TfmdDL(tds, after_batch=tfm, bs=4)
x = tdl.one_batch()
test_eq(x, torch.arange(4))
tds.split_idx = 1
x = tdl.one_batch()
test_eq(x, -torch.arange(4))
tds.split_idx = 0
x = tdl.one_batch()
test_eq(x, torch.arange(4))
tds = TfmdLists(start, A())
tdl = TfmdDL(tds, after_batch=NegTfm(), bs=4)
test_eq(tdl.dataset[0], start[0])
test_eq(len(tdl), (len(tds)-1)//4+1)
test_eq(tdl.bs, 4)
test_stdout(tdl.show_batch, '0\n1\n2\n3')
TfmdLists.subset
TfmdLists.subset (i)
New TfmdLists with the same tfms that only includes items in the i-th split
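A small sketch: with splits defined, subset(0) and subset(1) are the views aliased to train and valid (the items here are placeholders).

tl = TfmdLists(range(5), tfms=[], splits=[[0,1,2],[3,4]])
test_eq(tl.subset(1).items, [3,4])   # items of the second split
test_eq(tl.subset(0), tl.train)      # subset 0 is the training view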
TfmdLists.infer_idx
TfmdLists.infer_idx (x)
Finds the index where self.tfms can be applied to x, depending on the type of x
TfmdLists.infer
TfmdLists.infer (x)
Apply self.tfms to x starting at the right tfm, depending on the type of x
def mult(x): return x*2
mult.order = 2

fns = ['dog_0.jpg','cat_0.jpg','cat_2.jpg','cat_1.jpg','dog_1.jpg']
tl = TfmdLists(fns, [_lbl,_Cat(),mult])

test_eq(tl.infer_idx('dog_45.jpg'), 0)
test_eq(tl.infer('dog_45.jpg'), 2)

test_eq(tl.infer_idx(4), 2)
test_eq(tl.infer(4), 8)

test_fail(lambda: tl.infer_idx(2.0))
test_fail(lambda: tl.infer(2.0))
Datasets
Datasets (items:list=None, tfms:MutableSequence|Pipeline=None, tls:TfmdLists=None, n_inp:int=None, dl_type=None, use_list:bool=None, do_setup:bool=True, split_idx:int=None, train_setup:bool=True, splits:list=None, types=None, verbose:bool=False)
A dataset that creates a tuple from each tfms
| | Type | Default | Details |
|---|---|---|---|
| items | list | None | List of items to create Datasets |
| tfms | collections.abc.MutableSequence \| fastcore.transform.Pipeline | None | List of Transform(s) or Pipeline to apply |
| tls | TfmdLists | None | If None, self.tls is generated from items and tfms |
| n_inp | int | None | Number of elements in Datasets tuple that should be considered part of input |
| dl_type | NoneType | None | Default type of DataLoader used when function FilteredBase.dataloaders is called |
| use_list | bool | None | Use list in L |
| do_setup | bool | True | Call setup() for Transform |
| split_idx | int | None | Apply Transform(s) to training or validation set; 0 for training set and 1 for validation set |
| train_setup | bool | True | Apply Transform(s) only on training DataLoader |
| splits | list | None | Indices for training and validation sets |
| types | NoneType | None | Types of data in items |
| verbose | bool | False | Print verbose output |
A Datasets creates a tuple from items (typically input,target) by applying to them each list of Transform (or Pipeline) in tfms. Note that if tfms contains only one list of tfms, the items given by Datasets will be tuples of one element.

n_inp is the number of elements in the tuples that should be considered part of the input; it defaults to 1 if tfms consists of one set of transforms, and to len(tfms)-1 otherwise. In most cases, the number of elements in the tuples spit out by Datasets will be 2 (for input,target), but it can happen that there are 3 (Siamese networks or tabular data), in which case we need to be able to determine when the inputs end and the targets begin.
items = [1,2,3,4]
dsets = Datasets(items, [[neg_tfm,int2f_tfm], [add(1)]])
t = dsets[0]
test_eq(t, (-1,2))
test_eq(dsets[0,1,2], [(-1,2),(-2,3),(-3,4)])
test_eq(dsets.n_inp, 1)
dsets.decode(t)

(1.0, 2)
class Norm(Transform):
    def encodes(self, o): return (o-self.m)/self.s
    def decodes(self, o): return (o*self.s)+self.m
    def setups(self, items):
        its = tensor(items).float()
        self.m,self.s = its.mean(),its.std()

items = [1,2,3,4]
nrm = Norm()
dsets = Datasets(items, [[neg_tfm,int2f_tfm], [neg_tfm,nrm]])

x,y = zip(*dsets)
test_close(tensor(y).mean(), 0)
test_close(tensor(y).std(), 1)
test_eq(x, (-1,-2,-3,-4,))
test_eq(nrm.m, -2.5)
test_stdout(lambda: show_at(dsets, 1), '-2')

test_eq(dsets.m, nrm.m)
test_eq(dsets.norm.m, nrm.m)
test_eq(dsets.train.norm.m, nrm.m)
test_fns = ['dog_0.jpg','cat_0.jpg','cat_2.jpg','cat_1.jpg','kid_1.jpg']
tcat = _Cat()
dsets = Datasets(test_fns, [[tcat,_lbl]], splits=[[0,1,2], [3,4]])
test_eq(tcat.vocab, ['cat','dog'])
test_eq(dsets.train, [(1,),(0,),(0,)])
test_eq(dsets.valid[0], (0,))
test_stdout(lambda: show_at(dsets.train, 0), "dog")
inp = [0,1,2,3,4]
dsets = Datasets(inp, tfms=[None])

test_eq(*dsets[2], 2)               # Retrieve one item (subset 0 is the default)
test_eq(dsets[1,2], [(1,),(2,)])    # Retrieve two items by index
mask = [True,False,False,True,False]
test_eq(dsets[mask], [(0,),(3,)])   # Retrieve two items by mask
inp = pd.DataFrame(dict(a=[5,1,2,3,4]))
dsets = Datasets(inp, tfms=attrgetter('a')).subset(0)
test_eq(*dsets[2], 2)               # Retrieve one item (subset 0 is the default)
test_eq(dsets[1,2], [(1,),(2,)])    # Retrieve two items by index
mask = [True,False,False,True,False]
test_eq(dsets[mask], [(5,),(3,)])   # Retrieve two items by mask
#test n_inp
inp = [0,1,2,3,4]
dsets = Datasets(inp, tfms=[None])
test_eq(dsets.n_inp, 1)
dsets = Datasets(inp, tfms=[[None],[None],[None]])
test_eq(dsets.n_inp, 2)
dsets = Datasets(inp, tfms=[[None],[None],[None]], n_inp=1)
test_eq(dsets.n_inp, 1)
# splits can be indices
dsets = Datasets(range(5), tfms=[None], splits=[tensor([0,2]), [1,3,4]])

test_eq(dsets.subset(0), [(0,),(2,)])
test_eq(dsets.train, [(0,),(2,)])          # Subset 0 is aliased to `train`
test_eq(dsets.subset(1), [(1,),(3,),(4,)])
test_eq(dsets.valid, [(1,),(3,),(4,)])     # Subset 1 is aliased to `valid`
test_eq(*dsets.valid[2], 4)
#assert '[(1,),(3,),(4,)]' in str(dsets) and '[(0,),(2,)]' in str(dsets)
dsets

(#5) [(0,),(1,),(2,),(3,),(4,)]
# splits can be boolean masks (they don't have to cover all items, but must be disjoint)
splits = [[False,True,True,False,True], [True,False,False,False,False]]
dsets = Datasets(range(5), tfms=[None], splits=splits)

test_eq(dsets.train, [(1,),(2,),(4,)])
test_eq(dsets.valid, [(0,)])
# apply transforms to all items
tfm = [[lambda x: x*2,lambda x: x+1]]
splits = [[1,2],[0,3,4]]
dsets = Datasets(range(5), tfm, splits=splits)
test_eq(dsets.train,[(3,),(5,)])
test_eq(dsets.valid,[(1,),(7,),(9,)])
test_eq(dsets.train[False,True], [(5,)])
# only transform subset 1
class _Tfm(Transform):
    split_idx=1
    def encodes(self, x): return x*2
    def decodes(self, x): return TitledStr(x//2)

dsets = Datasets(range(5), [_Tfm()], splits=[[1,2],[0,3,4]])
test_eq(dsets.train,[(1,),(2,)])
test_eq(dsets.valid,[(0,),(6,),(8,)])
test_eq(dsets.train[False,True], [(2,)])
dsets

(#5) [(0,),(1,),(2,),(3,),(4,)]
#A context manager to change the split_idx and apply the validation transform on the training set
ds = dsets.train
with ds.set_split_idx(1):
    test_eq(ds,[(2,),(4,)])
test_eq(dsets.train,[(1,),(2,)])
dsets = Datasets(range(5), [_Tfm(),noop], splits=[[1,2],[0,3,4]])
test_eq(dsets.train,[(1,1),(2,2)])
test_eq(dsets.valid,[(0,0),(6,3),(8,4)])
start = torch.arange(0,50)
tds = Datasets(start, [A()])
tdl = TfmdDL(tds, after_item=NegTfm(), bs=4)
b = tdl.one_batch()
test_eq(tdl.decode_batch(b), ((0,),(1,),(2,),(3,)))
test_stdout(tdl.show_batch, "0\n1\n2\n3")
# only transform subset 1
class _Tfm(Transform):
    split_idx=1
    def encodes(self, x): return x*2

dsets = Datasets(range(8), [None], splits=[[1,2,5,7],[0,3,4,6]])
dls = dsets.dataloaders(bs=4, after_batch=_Tfm(), shuffle=False, device=torch.device('cpu'))
test_eq(dls.train, [(tensor([1,2,5,7]),)])
test_eq(dls.valid, [(tensor([0,6,8,12]),)])
test_eq(dls.n_inp, 1)
Methods
items = [1,2,3,4]
dsets = Datasets(items, [[neg_tfm,int2f_tfm]])
Datasets.dataloaders
Datasets.dataloaders (bs:int=64, shuffle_train:bool=None, shuffle:bool=True, val_shuffle:bool=False, n:int=None, path:str|Path='.', dl_type:TfmdDL=None, dl_kwargs:list=None, device:torch.device=None, drop_last:bool=None, val_bs:int=None, num_workers:int=None, verbose:bool=False, do_setup:bool=True, pin_memory=False, timeout=0, batch_size=None, indexed=None, persistent_workers=False, pin_memory_device='', wif=None, before_iter=None, after_item=None, before_batch=None, after_batch=None, after_iter=None, create_batches=None, create_item=None, create_batch=None, retain=None, get_idxs=None, sample=None, shuffle_fn=None, do_batch=None)
Get a DataLoaders
| | Type | Default | Details |
|---|---|---|---|
| bs | int | 64 | Batch size |
| shuffle_train | bool | None | (Deprecated, use shuffle) Shuffle training DataLoader |
| shuffle | bool | True | Shuffle training DataLoader |
| val_shuffle | bool | False | Shuffle validation DataLoader |
| n | int | None | Size of Datasets used to create DataLoader |
| path | str \| Path | . | Path to put in DataLoaders |
| dl_type | TfmdDL | None | Type of DataLoader |
| dl_kwargs | list | None | List of kwargs to pass to individual DataLoaders |
| device | torch.device | None | Device to put DataLoaders |
| drop_last | bool | None | Drop last incomplete batch, defaults to shuffle |
| val_bs | int | None | Validation batch size, defaults to bs |
| num_workers | int | None | Number of CPU cores to use in parallel (default: All available up to 16) |
| verbose | bool | False | Whether to print verbose logs |
| do_setup | bool | True | Whether to run setup() for batch transform(s) |
| pin_memory | bool | False | |
| timeout | int | 0 | |
| batch_size | NoneType | None | |
| indexed | NoneType | None | |
| persistent_workers | bool | False | |
| pin_memory_device | str | '' | |
| wif | NoneType | None | |
| before_iter | NoneType | None | |
| after_item | NoneType | None | |
| before_batch | NoneType | None | |
| after_batch | NoneType | None | |
| after_iter | NoneType | None | |
| create_batches | NoneType | None | |
| create_item | NoneType | None | |
| create_batch | NoneType | None | |
| retain | NoneType | None | |
| get_idxs | NoneType | None | |
| sample | NoneType | None | |
| shuffle_fn | NoneType | None | |
| do_batch | NoneType | None | |
| Returns | DataLoaders | | |
Used to create dataloaders. You may prepend 'val_' (as in val_shuffle) to override functionality for the validation set. dl_kwargs gives finer, per-dataloader control if you need to work with more than one dataloader.
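A sketch of those overrides with illustrative values (the items, splits and worker counts below are placeholders):

dsets_s = Datasets(range(10), [[neg_tfm]], splits=[list(range(8)), [8,9]])
dls = dsets_s.dataloaders(
    bs=4, val_bs=2,                  # validation batch size overrides bs
    shuffle=True, val_shuffle=False, # shuffle train only
    dl_kwargs=[{'num_workers':0}, {'num_workers':0}],  # one dict per DataLoader
)
test_eq(dls.valid.bs, 2)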
Datasets.decode
Datasets.decode (o, full=True)
Compose decode of all tuple_tfms then all tfms on i
test_eq(*dsets[0], -1)
test_eq(*dsets.decode((-1,)), 1)
Datasets.show
Datasets.show (o, ctx=None, **kwargs)
Show item o in ctx

test_stdout(lambda: dsets.show(dsets[1]), '-2')
Datasets.new_empty
Datasets.new_empty ()
Create a new empty version of self, keeping only the transforms
items = [1,2,3,4]
nrm = Norm()
dsets = Datasets(items, [[neg_tfm,int2f_tfm], [neg_tfm]])
empty = dsets.new_empty()
test_eq(empty.items, [])
Add test set for inference
# only transform subset 1
class _Tfm1(Transform):
    split_idx=0
    def encodes(self, x): return x*3

dsets = Datasets(range(8), [[_Tfm(),_Tfm1()]], splits=[[1,2,5,7],[0,3,4,6]])
test_eq(dsets.train, [(3,),(6,),(15,),(21,)])
test_eq(dsets.valid, [(0,),(6,),(8,),(12,)])
test_set
test_set (dsets:__main__.Datasets|__main__.TfmdLists, test_items, rm_tfms=None, with_labels:bool=False)
Create a test set from test_items using validation transforms of dsets
| | Type | Default | Details |
|---|---|---|---|
| dsets | Datasets \| TfmdLists | | Map- or iterable-style dataset from which to load the data |
| test_items | | | Items in test dataset |
| rm_tfms | NoneType | None | Start index of Transform(s) from validation set in dsets to apply |
| with_labels | bool | False | Whether the test items contain labels |
class _Tfm1(Transform):
    split_idx=0
    def encodes(self, x): return x*3

dsets = Datasets(range(8), [[_Tfm(),_Tfm1()]], splits=[[1,2,5,7],[0,3,4,6]])
test_eq(dsets.train, [(3,),(6,),(15,),(21,)])
test_eq(dsets.valid, [(0,),(6,),(8,),(12,)])

#Transforms of the validation set are applied
tst = test_set(dsets, [1,2,3])
test_eq(tst, [(2,),(4,),(6,)])
DataLoaders.test_dl
DataLoaders.test_dl (test_items, rm_type_tfms=None, with_labels:bool=False, bs:int=64, shuffle:bool=False, num_workers:int=None, verbose:bool=False, do_setup:bool=True, pin_memory=False, timeout=0, batch_size=None, drop_last=False, indexed=None, n=None, device=None, persistent_workers=False, pin_memory_device='', wif=None, before_iter=None, after_item=None, before_batch=None, after_batch=None, after_iter=None, create_batches=None, create_item=None, create_batch=None, retain=None, get_idxs=None, sample=None, shuffle_fn=None, do_batch=None)
Create a test dataloader from test_items using validation transforms of dls
| | Type | Default | Details |
|---|---|---|---|
| test_items | | | Items in test dataset |
| rm_type_tfms | NoneType | None | Start index of Transform(s) from validation set in dsets to apply |
| with_labels | bool | False | Whether the test items contain labels |
| bs | int | 64 | Size of batch |
| shuffle | bool | False | Whether to shuffle data |
| num_workers | int | None | Number of CPU cores to use in parallel (default: All available up to 16) |
| verbose | bool | False | Whether to print verbose logs |
| do_setup | bool | True | Whether to run setup() for batch transform(s) |
| pin_memory | bool | False | |
| timeout | int | 0 | |
| batch_size | NoneType | None | |
| drop_last | bool | False | |
| indexed | NoneType | None | |
| n | NoneType | None | |
| device | NoneType | None | |
| persistent_workers | bool | False | |
| pin_memory_device | str | '' | |
| wif | NoneType | None | |
| before_iter | NoneType | None | |
| after_item | NoneType | None | |
| before_batch | NoneType | None | |
| after_batch | NoneType | None | |
| after_iter | NoneType | None | |
| create_batches | NoneType | None | |
| create_item | NoneType | None | |
| create_batch | NoneType | None | |
| retain | NoneType | None | |
| get_idxs | NoneType | None | |
| sample | NoneType | None | |
| shuffle_fn | NoneType | None | |
| do_batch | NoneType | None | |
dsets = Datasets(range(8), [[_Tfm(),_Tfm1()]], splits=[[1,2,5,7],[0,3,4,6]])
dls = dsets.dataloaders(bs=4, device=torch.device('cpu'))
tst_dl = dls.test_dl([2,3,4,5])
test_eq(tst_dl._n_inp, 1)
test_eq(list(tst_dl), [(tensor([ 4, 6, 8, 10]),)])
#Test you can change transforms
tst_dl = dls.test_dl([2,3,4,5], after_item=add1)
test_eq(list(tst_dl), [(tensor([ 5, 7, 9, 11]),)])