A basic model that can be used on tabular data




Rule of thumb to pick embedding size corresponding to n_cat

This rule of thumb, found through trial and error, takes the lower of two values:

  • A dimension space of 600
  • A dimension space equal to 1.6 times the cardinality of the variable raised to the power of 0.56, rounded to the nearest integer.

This gives a reasonable starting embedding size for each variable. More advanced users who want to lean into this practice can tweak these values at their discretion; slight adjustments to this general formula can sometimes give better results.
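As a minimal sketch, the rule can be written as a standalone Python function. The name emb_sz_rule mirrors a helper of that name in fastai.tabular.model, but this version does not import fastai and is only meant to illustrate the arithmetic.

```python
def emb_sz_rule(n_cat: int) -> int:
    """Embedding size for a categorical variable with n_cat levels:
    the smaller of 600 and 1.6 * n_cat**0.56, rounded to the nearest integer."""
    return min(600, round(1.6 * n_cat ** 0.56))

# Print the suggested size for a few cardinalities
for n_cat in (4, 17, 1_000, 1_000_000):
    print(n_cat, emb_sz_rule(n_cat))
```

With this formula the 600 cap only takes over for very high-cardinality variables (roughly 40,000 levels and above); below that, the power law decides the size.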


get_emb_sz(to, sz_dict=None)

Get the default embedding sizes from the TabularPreprocessor proc, or use the sizes given in sz_dict instead.
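A sketch of the lookup logic, assuming the category levels are available as a plain dictionary mapping each column name to its list of levels (a hypothetical stand-in for the classes held by the preprocessed object). Each column gets an (n_cat, size) pair, where the size comes from sz_dict when the column is listed there and from the rule of thumb otherwise. get_emb_sz_sketch is a hypothetical name, not fastai's function.

```python
from typing import Dict, List, Optional, Tuple

def emb_sz_rule(n_cat: int) -> int:
    # Rule of thumb: min(600, 1.6 * n_cat**0.56), rounded to the nearest integer
    return min(600, round(1.6 * n_cat ** 0.56))

def get_emb_sz_sketch(classes: Dict[str, List[str]],
                      sz_dict: Optional[Dict[str, int]] = None) -> List[Tuple[int, int]]:
    """Hypothetical helper: one (n_cat, emb_size) pair per categorical column.
    Sizes in sz_dict override the rule of thumb for the columns they name."""
    sz_dict = sz_dict or {}
    out = []
    for name, levels in classes.items():
        n_cat = len(levels)
        out.append((n_cat, sz_dict.get(name, emb_sz_rule(n_cat))))
    return out

classes = {"color": ["red", "green", "blue", "yellow"],
           "city": [f"city_{i}" for i in range(17)]}
print(get_emb_sz_sketch(classes))               # [(4, 3), (17, 8)]
print(get_emb_sz_sketch(classes, {"city": 5}))  # [(4, 3), (17, 5)]
```

The resulting list has the same (cardinality, size) shape as the emb_szs argument that TabularModel expects.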

class TabularModel

TabularModel(emb_szs, n_cont, out_sz, layers, ps=None, embed_p=0.0, y_range=None, use_bn=True, bn_final=False, bn_cont=True, act_cls=ReLU(inplace=True)) :: Module

Basic model for tabular data.

This model expects your categorical (cat) and continuous (cont) variables to be passed separately. The cat variables go through an Embedding layer and optional Dropout, while the cont variables go through an optional BatchNorm1d. Both are then concatenated and passed through a series of LinBnDrop blocks, followed by a final Linear layer whose size matches the expected number of outputs.
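The data flow described above can be sketched in plain PyTorch. TinyTabularModel is a hypothetical, simplified class, not fastai's implementation: each hidden block here is Linear, ReLU, BatchNorm1d, Dropout as a stand-in for LinBnDrop, and the layer order and options inside fastai's LinBnDrop may differ.

```python
import torch
import torch.nn as nn

class TinyTabularModel(nn.Module):
    """Hypothetical simplified sketch of the TabularModel data flow."""
    def __init__(self, emb_szs, n_cont, out_sz, layers, p=0.1, embed_p=0.0):
        super().__init__()
        # One embedding per categorical variable: (n_levels, emb_dim)
        self.embeds = nn.ModuleList([nn.Embedding(ni, nf) for ni, nf in emb_szs])
        self.emb_drop = nn.Dropout(embed_p)
        self.bn_cont = nn.BatchNorm1d(n_cont)
        n_emb = sum(nf for _, nf in emb_szs)
        sizes = [n_emb + n_cont] + list(layers)
        blocks = []
        for n_in, n_out in zip(sizes[:-1], sizes[1:]):
            # Stand-in for a LinBnDrop block
            blocks += [nn.Linear(n_in, n_out), nn.ReLU(inplace=True),
                       nn.BatchNorm1d(n_out), nn.Dropout(p)]
        # Final Linear layer sized to the number of outputs
        self.layers = nn.Sequential(*blocks, nn.Linear(sizes[-1], out_sz))

    def forward(self, x_cat, x_cont):
        # Embed each categorical column, concatenate, then apply embedding dropout
        x = torch.cat([e(x_cat[:, i]) for i, e in enumerate(self.embeds)], dim=1)
        x = self.emb_drop(x)
        # Normalize continuous inputs and join them with the embeddings
        x = torch.cat([x, self.bn_cont(x_cont)], dim=1)
        return self.layers(x)

m = TinyTabularModel([(4, 2), (17, 8)], n_cont=2, out_sz=2, layers=[200, 100]).eval()
out = m(torch.tensor([[2, 12]]), torch.tensor([[0.7633, -0.1887]]))
print(out.shape)  # torch.Size([1, 2])
```

Calling .eval() matters for a single-row batch: BatchNorm1d then uses its running statistics instead of batch statistics, which cannot be computed from one row in training mode.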

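import torch
from fastai.tabular.model import TabularModel

# Two categorical variables: 4 levels -> 2-dim embedding, 17 levels -> 8-dim embedding
emb_szs = [(4,2), (17,8)]
m = TabularModel(emb_szs, n_cont=2, out_sz=2, layers=[200,100]).eval()
x_cat = torch.tensor([[2,12]]).long()               # one row of category indices
x_cont = torch.tensor([[0.7633, -0.1887]]).float()  # one row of two continuous values
out = m(x_cat, x_cont)                              # shape (1, 2): one row, out_sz outputs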


tabular_config(ps=None, embed_p=0.0, y_range=None, use_bn=True, bn_final=False, bn_cont=True, act_cls=ReLU(inplace=True))

Convenience function to easily create a config for TabularModel

Any direct setup of TabularModel's internals should be passed through here:

config = tabular_config(embed_p=0.6, use_bn=False); config
{'embed_p': 0.6, 'use_bn': False}
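A minimal sketch that reproduces the behaviour seen in the output above, where only the arguments that differ from their defaults are kept. tabular_config_sketch is a hypothetical name, it omits act_cls for simplicity, and fastai's own implementation may work differently.

```python
import inspect

def tabular_config_sketch(ps=None, embed_p=0.0, y_range=None, use_bn=True,
                          bn_final=False, bn_cont=True):
    """Hypothetical re-implementation: return only the arguments that differ from their defaults."""
    passed = dict(locals())  # snapshot of the arguments as called
    defaults = {name: p.default
                for name, p in inspect.signature(tabular_config_sketch).parameters.items()}
    return {k: v for k, v in passed.items() if v != defaults[k]}

config = tabular_config_sketch(embed_p=0.6, use_bn=False)
print(config)  # {'embed_p': 0.6, 'use_bn': False}
```

Because every key matches a keyword argument in the TabularModel signature, such a dictionary can be unpacked straight into the constructor, for example TabularModel(emb_szs, n_cont=2, out_sz=2, layers=[200,100], **config).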