Callbacks which work with a learner's data
# Toy weighted-sampling setup: 10 random binary labels over the items 0..9.
# Per the splitter below, items with index >= 8 go to the validation set
# (train=8, valid=2).
lbls = np.random.randint(0, 2, size=10)
is_valid = lambda i: i >= 8
dblock = DataBlock(
    blocks=[CategoryBlock],
    getters=[lambda i: lbls[i]],
    splitter=FuncSplitter(is_valid),
)
dset = dblock.datasets(list(range(10)))
item_tfms = [ToTensor()]
# One weight per training item, hence exactly 8 weights.
wgts = range(8)
dls = dset.weighted_dataloaders(wgts=wgts, bs=1, after_item=item_tfms)
# Expected to fail if len(wgts) != 8.
dls.show_batch()
# Visual check of weighted sampling: a dataset of the floats 0..n-1 where
# item i is given weight i, trained for one epoch while CollectDataCallback
# records what the learner saw.
n = 160
dsets = Datasets(torch.arange(n).float())
dls = dsets.weighted_dataloaders(bs=16, wgts=range(n))
learn = synth_learner(cbs=CollectDataCallback, data=dls)
learn.fit(1)
# itemgot(0, 0) takes element [0][0] of every collected entry -- presumably
# the batch inputs; confirm against CollectDataCallback.
t = concat(
    *learn.collect_data.data.itemgot(0, 0)
)
# Trailing semicolon keeps the notebook from echoing the return value.
plt.hist(t.numpy());
# Same weighted-sampling example, but built directly from the DataBlock
# instead of from pre-built datasets. Reuses `dblock` and `wgts` defined in
# the earlier cell.
# NOTE(review): `wgts` has 8 entries while the source has 10 items -- confirm
# whether this entry point expects one weight per training item or per
# source item.
dls = dblock.weighted_dataloaders(list(range(10)), wgts, bs=1)
dls.show_batch()
# Check partial_dataloaders on the `dsets` built earlier: with partial_n=32
# and bs=16 the first (training) loader must yield exactly 2 batches, each
# holding 16 items.
dls = dsets.partial_dataloaders(partial_n=32, bs=16)
assert len(dls[0])==2
for batch in dls[0]:
    # batch[0] is the first element of the batch tuple.
    assert len(batch[0])==16