The main class to get your data ready for model training is TabularDataLoaders,
together with its factory methods. Check out the tabular tutorial for examples of use.
This class should not be used directly; prefer one of the factory methods instead. All of these factory methods accept the following arguments:
cat_names: the names of the categorical variables
cont_names: the names of the continuous variables
y_names: the names of the dependent variables
y_block: the TransformBlock to use for the target
valid_idx: the indices to use for the validation set (defaults to a random split otherwise)
bs: the batch size
val_bs: the batch size for the validation DataLoader (defaults to bs)
shuffle_train: whether to shuffle the training DataLoader
n: overrides the number of elements in the dataset
device: the PyTorch device to use (defaults to default_device())
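To make the valid_idx argument concrete, here is a minimal sketch of how a list of validation indices partitions the rows of a dataset. The helper split_indices, its valid_pct and seed parameters, and the 20% hold-out used for the random case are assumptions made for illustration; this is not fastai's implementation.

```python
import random

def split_indices(n_rows, valid_idx=None, valid_pct=0.2, seed=None):
    """Return (train_indices, valid_indices) for a dataset of n_rows rows.

    Hypothetical helper: valid_pct and seed are illustrative parameters,
    not arguments of the factory methods described above.
    """
    if valid_idx is None:
        # No explicit indices: fall back to a random split (assumed 20% hold-out).
        rng = random.Random(seed)
        idxs = list(range(n_rows))
        rng.shuffle(idxs)
        cut = int(valid_pct * n_rows)
        return sorted(idxs[cut:]), sorted(idxs[:cut])
    # Explicit indices: everything not listed goes to the training set.
    valid = sorted(set(valid_idx))
    valid_set = set(valid)
    train = [i for i in range(n_rows) if i not in valid_set]
    return train, valid

# Rows 800..999 form the validation set, as in valid_idx=list(range(800, 1000)).
train, valid = split_indices(1000, valid_idx=range(800, 1000))
```

Passing explicit indices makes the split reproducible across runs, which is useful when comparing models on the same held-out rows.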
Let's look at an example with the adult dataset:
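from fastai.tabular.all import *  # provides untar_data, URLs, pd, TabularDataLoaders and the procs

path = untar_data(URLs.ADULT_SAMPLE)
# skipinitialspace=True strips the space that follows each comma in adult.csv
df = pd.read_csv(path/'adult.csv', skipinitialspace=True)
df.head()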
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
dls = TabularDataLoaders.from_df(df, path, procs=procs, cat_names=cat_names, cont_names=cont_names,
                                 y_names="salary", valid_idx=list(range(800,1000)), bs=64)
dls.show_batch()
The same DataLoaders can also be built directly from the CSV file with from_csv:
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
dls = TabularDataLoaders.from_csv(path/'adult.csv', path=path, procs=procs, cat_names=cat_names, cont_names=cont_names,
                                  y_names="salary", valid_idx=list(range(800,1000)), bs=64)
External structured data files can contain unexpected spaces, e.g. after a comma. We can see this in the first row of adult.csv: "49, Private,101320, ...". Trimming is often needed. Pandas has a convenient parameter, skipinitialspace, that is exposed by TabularDataLoaders.from_csv(). Without it, a category label used later for inference, such as workclass: Private, will be wrongly categorized as 0 or "#na#" if the training label was read as " Private". Let's test this feature.
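# A single inference row whose 'workclass' label has no leading space
test_data = {
    'age': [49],
    'workclass': ['Private'],
    'fnlwgt': [101320],
    'education': ['Assoc-acdm'],
    'education-num': [12.0],
    'marital-status': ['Married-civ-spouse'],
    'occupation': [''],
    'relationship': ['Wife'],
    'race': ['White'],
}
test_df = pd.DataFrame(test_data)  # named test_df to avoid shadowing the built-in input()
tdl = dls.test_dl(test_df)
# 'Private' must not be encoded as 0, the index reserved for "#na#"
test_ne(0, tdl.dataset.iloc[0]['workclass'])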
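The whitespace problem described above can also be reproduced without fastai. The sketch below is only an illustration: it uses Python's standard csv module, whose skipinitialspace option plays the same role as the pandas parameter of the same name, and a hypothetical build_vocab/encode pair that mimics a category vocabulary with index 0 reserved for "#na#". It is not how Categorify is implemented.

```python
import csv
import io

# Header plus the first data row quoted above; note the space after the first comma.
raw = "age,workclass,fnlwgt\n49, Private,101320\n"

def read_workclass(text, skipinitialspace):
    """Return the 'workclass' column, optionally dropping spaces after delimiters."""
    reader = csv.DictReader(io.StringIO(text), skipinitialspace=skipinitialspace)
    return [row["workclass"] for row in reader]

def build_vocab(labels):
    """Hypothetical vocabulary: index 0 is reserved for the unknown label "#na#"."""
    return {label: i for i, label in enumerate(["#na#"] + sorted(set(labels)))}

def encode(label, vocab):
    """Map a label to its index, falling back to 0 ("#na#") when it was never seen."""
    return vocab.get(label, 0)

untrimmed = read_workclass(raw, skipinitialspace=False)  # label keeps its leading space
trimmed = read_workclass(raw, skipinitialspace=True)     # leading space removed

# At inference time the label arrives without a leading space.
code_untrimmed = encode("Private", build_vocab(untrimmed))
code_trimmed = encode("Private", build_vocab(trimmed))
```

With the untrimmed vocabulary the lookup for "Private" misses and falls back to the "#na#" index, which is exactly the failure the test above guards against.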