DataLoaders

The DataLoader class
bs = 4
letters = list(string.ascii_lowercase)

DataLoader helpers

fastai includes a replacement for PyTorch’s DataLoader which is largely API-compatible, and adds a lot of useful functionality and flexibility. Before we look at the class, there are a couple of helpers we’ll need to define.


source

fa_collate

 fa_collate (t)

A replacement for PyTorch default_collate which maintains types and handles Sequences

#e.g. x is int, y is tuple
t = [(1,(2,3)),(1,(2,3))]
test_eq(fa_collate(t), default_collate(t))
test_eq(L(fa_collate(t)).map(type), [Tensor,tuple])

t = [(1,(2,(3,4))),(1,(2,(3,4)))]
test_eq(fa_collate(t), default_collate(t))
test_eq(L(fa_collate(t)).map(type), [Tensor,tuple])
test_eq(L(fa_collate(t)[1]).map(type), [Tensor,tuple])

source

fa_convert

 fa_convert (t)

A replacement for PyTorch default_convert which maintains types and handles Sequences

t0 = array([1,2])
t = [t0,(t0,t0)]

test_eq(fa_convert(t), default_convert(t))
test_eq(L(fa_convert(t)).map(type), [Tensor,tuple])

source

SkipItemException

Raised to notify DataLoader to skip an item


source

collate_error

 collate_error (e:Exception, batch)

Raises error when the batch could not collate, stating what items in the batch are different sizes and their types


source

DataLoader

 DataLoader (dataset=None, bs=None, num_workers=0, pin_memory=False,
             timeout=0, batch_size=None, shuffle=False, drop_last=False,
             indexed=None, n=None, device=None, persistent_workers=False,
             pin_memory_device='', wif=None, before_iter=None,
             after_item=None, before_batch=None, after_batch=None,
             after_iter=None, create_batches=None, create_item=None,
             create_batch=None, retain=None, get_idxs=None, sample=None,
             shuffle_fn=None, do_batch=None)

Inherit from this to have all attr accesses in self._xtra passed down to self.default

Arguments to DataLoader:

  • dataset: dataset from which to load the data. Can be either map-style or iterable-style dataset.
  • bs (int): how many samples per batch to load (if batch_size is provided then batch_size will override bs). If bs=None, then it is assumed that dataset.__getitem__ returns a batch.
  • num_workers (int): how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process.
  • pin_memory (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning them.
  • timeout (float>=0): if positive, the timeout value in seconds for collecting a batch from workers (defaults to 0).
  • batch_size (int): It is only provided for PyTorch compatibility. Use bs.
  • shuffle (bool): If True, then data is shuffled every time dataloader is fully read/iterated.
  • drop_last (bool): If True, then the last incomplete batch is dropped.
  • indexed (bool): The DataLoader will make a guess as to whether the dataset can be indexed (or is iterable), but you can override it with this parameter. True by default.
  • n (int): Defaults to len(dataset). If you are using iterable-style dataset, you can specify the size with n.
  • device (torch.device): Defaults to default_device() which is CUDA by default. You can specify device as torch.device('cpu').

Override create_item and use the default infinite sampler to get a stream of unknown length (stop() when you want to stop the stream).

class RandDL(DataLoader):
    "DataLoader whose items are random floats; the stream is stopped at the first draw of 0.95 or more"
    def create_item(self, s):
        r = random.random()
        # Guard clause: draws under the threshold are the items of the stream
        if r < 0.95: return r
        # Otherwise end the stream via `stop()`
        return stop()

L(RandDL())
(#9) [0.09071201211613367,0.03249811556595483,0.6517029228593939,0.8584412116263038,0.759838440232556,0.3725873327679504,0.1445316323722865,0.18876233969606782,0.25518635091544917]
L(RandDL(bs=4, drop_last=True)).map(len)
(#1) [4]
dl = RandDL(bs=4, num_workers=4, drop_last=True)
L(dl).map(len)
(#1) [4]
test_num_workers = 0 if sys.platform in ("win32","darwin") else 4
test_eq(dl.fake_l.num_workers, test_num_workers)
with dl.fake_l.no_multiproc(): 
    test_eq(dl.fake_l.num_workers, 0)
    L(dl).map(len)
test_eq(dl.fake_l.num_workers, test_num_workers)
def _rand_item(s):
    "Return a random float below 0.95, or call `stop()` when the draw is 0.95 or more"
    value = random.random()
    if value >= 0.95: return stop()
    return value

L(DataLoader(create_item=_rand_item))
(#2) [0.624781366539204,0.39823513973618685]

If you don’t set bs, then dataset is assumed to provide an iterator or a __getitem__ that returns a batch.

ds1 = DataLoader(letters)
test_eq(L(ds1), letters)
test_eq(len(ds1), 26)

test_shuffled(L(DataLoader(letters, shuffle=True)), letters)

ds1 = DataLoader(letters, indexed=False)
test_eq(L(ds1), letters)
test_eq(len(ds1), 26)

t2 = L(tensor([0,1,2]),tensor([3,4,5]))
ds2 = DataLoader(t2)
test_eq_type(L(ds2), t2)

t3 = L(array([0,1,2], dtype=np.int64),array([3,4,5], dtype=np.int64))
ds3 = DataLoader(t3)
test_eq_type(L(ds3), t3.map(tensor))

ds4 = DataLoader(t3, create_batch=noop, after_iter=lambda: setattr(t3, 'f', 1))
test_eq_type(L(ds4), t3)
test_eq(t3.f, 1)

If you do set bs, then dataset is assumed to provide an iterator or a __getitem__ that returns a single item of a batch.

def twoepochs(d):
    "Iterate over `d` twice, joining each batch into one string and separating batches with spaces"
    joined = []
    for _ in range(2):
        # Each batch `o` is an iterable of strings
        for o in d: joined.append(''.join(o))
    return ' '.join(joined)
ds1 = DataLoader(letters, bs=4, drop_last=True, num_workers=0)
test_eq(twoepochs(ds1), 'abcd efgh ijkl mnop qrst uvwx abcd efgh ijkl mnop qrst uvwx')

ds1 = DataLoader(letters,4,num_workers=2)
test_eq(twoepochs(ds1), 'abcd efgh ijkl mnop qrst uvwx yz abcd efgh ijkl mnop qrst uvwx yz')

ds1 = DataLoader(range(12), bs=4, num_workers=3)
test_eq_type(L(ds1), L(tensor([0,1,2,3]),tensor([4,5,6,7]),tensor([8,9,10,11])))

ds1 = DataLoader([str(i) for i in range(11)], bs=4, after_iter=lambda: setattr(t3, 'f', 2))
test_eq_type(L(ds1), L(['0','1','2','3'],['4','5','6','7'],['8','9','10']))
test_eq(t3.f, 2)

it = iter(DataLoader(map(noop,range(20)), bs=4, num_workers=1))
test_eq_type([next(it) for _ in range(3)], [tensor([0,1,2,3]),tensor([4,5,6,7]),tensor([8,9,10,11])])

Iterable dataloaders require specific tests.

class DummyIterableDataset(IterableDataset):
    "Iterable-style dataset that yields the integers 0 through 10"
    def __iter__(self):
        for i in range(11): yield i

ds1 = DataLoader(DummyIterableDataset(), bs=4)
# Check it yields fine, and check we can do multiple passes
for i in range(3):
    test_eq_type(L(ds1), L(tensor([0,1,2,3]),tensor([4,5,6,7]),tensor([8,9,10])))

# Check `drop_last` works fine (with multiple passes, since this will prematurely terminate the iterator)
ds1 = DataLoader(DummyIterableDataset(), bs=4, drop_last=True)
for i in range(3):
    test_eq_type(L(ds1), L(tensor([0,1,2,3]),tensor([4,5,6,7])))
class SleepyDL(list):
    "A list whose item access first sleeps for a random duration of at most 1/50 second"
    def __getitem__(self, i):
        delay = random.random() / 50
        time.sleep(delay)
        item = super().__getitem__(i)
        return item

t = SleepyDL(letters)




dl = DataLoader(t, shuffle=True, num_workers=1)
test_shuffled(L(dl), letters)
test_shuffled(L(dl), L(dl))
L(dl)
CPU times: user 3.35 ms, sys: 890 µs, total: 4.24 ms
Wall time: 307 ms
CPU times: user 6.93 ms, sys: 860 µs, total: 7.79 ms
Wall time: 333 ms
CPU times: user 7.78 ms, sys: 722 µs, total: 8.51 ms
Wall time: 331 ms
(#26) ['l','h','f','r','z','s','u','x','m','p'...]
class SleepyQueue():
    "Iterable wrapper that drains a queue, pausing for a random interval before each read"
    def __init__(self, q):
        self.q = q

    def __iter__(self):
        while True:
            # Random pause of at most 1/100 second before every read attempt
            pause = random.random() / 100
            time.sleep(pause)
            try:
                yield self.q.get_nowait()
            except queues.Empty:
                # An empty queue ends the iteration
                return

q = Queue()
for o in range(30): q.put(o)
it = SleepyQueue(q)

if not (sys.platform == "win32" and IN_NOTEBOOK):
AssertionError: !=:
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
# Minimal `TensorBase` subclass, used to check that batches keep the subclass type
class A(TensorBase): pass

# Run the same checks with 0 workers (main process) and with 2 worker processes
for nw in (0,2):
    t = A(tensor([1,2]))
    dl = DataLoader([t,t,t,t,t,t,t,t], bs=4, num_workers=nw)
    b = first(dl)
    # The first batch is expected to have type `A`
    test_eq(type(b), A)

    # Same check when each item is a 1-tuple holding an `A`
    t = (A(tensor([1,2])),)
    dl = DataLoader([t,t,t,t,t,t,t,t], bs=4, num_workers=nw)
    b = first(dl)
    # The first element of the batch is expected to have type `A`
    test_eq(type(b[0]), A)
list(DataLoader(list(range(50)),bs=32,shuffle=True,num_workers=3))
[tensor([42, 12, 44, 21,  8,  6,  3, 37, 33,  9, 27, 34, 18, 26,  1, 23, 11, 41,
         15,  0, 49,  4, 38, 46, 48, 14, 40, 36, 17, 45, 30, 29]),
 tensor([19, 10, 22, 13, 25, 32, 35,  5,  2, 20, 47, 39, 16, 28, 43,  7, 31, 24])]
# Minimal `TensorBase` subclass for the type check below
class A(TensorBase): pass
t = A(tensor(1,2))

# With `to_device` passed as `after_batch`, the first batch is still expected to have type `A`
tdl = DataLoader([t,t,t,t,t,t,t,t], bs=4, num_workers=2, after_batch=to_device)
b = first(tdl)
test_eq(type(b), A)

# Unknown attributes are delegated to `dataset`
test_eq(tdl.pop(), tensor(1,2))

Override get_idxs to return the same index until consumption of the DL. This is intended to test consistent sampling behavior when num_workers>1.

class AdamantDL(DataLoader):
    "DataLoader whose `get_idxs` returns a single random index repeated `n` times"
    def get_idxs(self):
        # Draw one index in [0, n-1] and use it for every position
        idx = random.randint(0, self.n - 1)
        return [idx for _ in range(self.n)]

test_eq(torch.cat(tuple(AdamantDL((list(range(50))),bs=16,num_workers=4))).unique().numel(),1)
# from subprocess import Popen, PIPE
# # test num_workers > 0 in scripts works when python process start method is spawn
# process = Popen(["python", "dltest.py"], stdout=PIPE)
# _, err = process.communicate(timeout=15)
# exit_code = process.wait()
# test_eq(exit_code, 0)