TabNet: First Visit

Weekly Dev Blog
3 min read · Jan 25, 2021


With the growth of neural networks, DNNs have shown strong performance on text, image, and audio data.
However, on everyday tabular data we have yet to see success comparable to XGBoost or LightGBM.

Tree-based models have advantages in:

  • decision manifolds: hyperplane-like boundaries that can be extended indefinitely and partition tabular data well
  • interpretability
  • fast training

DNNs have advantages in:

  • encoding data (representation learning)
  • reducing feature engineering
  • online learning

Discussion: why are DNNs not performant on tabular data?

  • they are highly non-linear with no restriction on convergence, so they overfit more easily than tree ensembles
  • adding more layers causes overparameterization, which may be why DNNs underperform on tabular data

How good would it be if we had a framework for tabular data that is end-to-end, learns representations, and supports online updates?

Building a decision boundary with a DNN

(figures omitted: a tree's decision boundary vs. the equivalent DNN block)

We can treat the mask as a decision boundary: a mask + FC + ReLU block behaves like a vanilla decision tree.

This gives an additive model in which each output represents the weight of one condition contributing to the final decision, as in the sketch below.
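A minimal numpy sketch of this idea (my own illustration with made-up constants C, a, d, loosely following the tree-emulation figure in the TabNet paper): two mask + FC + ReLU branches are summed, and a softmax over the combined outputs indicates which tree-like region the sample falls in.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def branch(x, mask, w, b):
    # one mask + FC + ReLU branch
    return relu((x * mask) @ w + b)

C = 100.0          # large weight -> ReLU output approximates a step function
a, d = 0.4, -0.6   # thresholds, playing the role of a tree's split conditions
x = np.array([[0.7, -0.3]])  # one sample with two features

m1 = np.array([1.0, 0.0])    # branch 1 sees only x1
m2 = np.array([0.0, 1.0])    # branch 2 sees only x2
w1 = np.array([[  C,  -C, 0.0, 0.0],
               [0.0, 0.0, 0.0, 0.0]])
b1 = np.array([-C * a, C * a, 0.0, 0.0])
w2 = np.array([[0.0, 0.0, 0.0, 0.0],
               [0.0, 0.0,   C,  -C]])
b2 = np.array([0.0, 0.0, -C * d, C * d])

# additive combination of the two branches, then softmax over outputs
logits = branch(x, m1, w1, b1) + branch(x, m2, w2, b2)
probs = np.exp(logits - logits.max()) / np.exp(logits - logits.max()).sum()
print(probs)  # mass concentrates on the conditions x1 > a and x2 > d
```

With a large C the ReLU outputs approximate step functions, so the softmax mass lands on the conditions the sample satisfies, much as a tree routes a sample to a leaf.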

TabNet model structure

TabNet is a more complex additive model: given an input of shape (batch, features), it produces a single output vector per sample.

Layers

  • BN: batch normalization
  • Feature transformer: FC layers that compute the feature embedding
  • Split: divides the embedding into a decision output and the input to the next step
  • Attentive transformer: computes the feature-selection mask for the next step

Overall, TabNet chains these steps sequentially to construct an additive NN framework, as in the skeleton below.
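A hedged skeleton of that multi-step loop in PyTorch (the module internals here are simplified stand-ins; only the wiring of mask -> feature transformer -> split -> additive output follows the paper):

```python
import torch
import torch.nn as nn

class TabNetSketch(nn.Module):
    """Minimal sketch of TabNet's sequential multi-step additive structure."""
    def __init__(self, n_features, n_d=8, n_a=8, n_steps=3, gamma=1.3):
        super().__init__()
        self.n_d, self.n_steps, self.gamma = n_d, n_steps, gamma
        self.bn = nn.BatchNorm1d(n_features)
        # stand-in feature transformers: one FC producing [d, a] jointly
        self.feat = nn.ModuleList(
            nn.Linear(n_features, n_d + n_a) for _ in range(n_steps + 1))
        # stand-in attentive transformers: FC from a -> per-feature logits
        self.attn = nn.ModuleList(
            nn.Linear(n_a, n_features) for _ in range(n_steps))
        self.head = nn.Linear(n_d, 1)

    def forward(self, x):
        x = self.bn(x)
        prior = torch.ones_like(x)          # tracks how much each feature may still be used
        a = self.feat[0](x)[:, self.n_d:]   # initial attention features
        out = torch.zeros(x.size(0), self.n_d, device=x.device)
        for step in range(self.n_steps):
            logits = self.attn[step](a) * prior
            mask = torch.softmax(logits, dim=-1)   # the paper uses sparsemax here
            prior = prior * (self.gamma - mask)    # discourage reusing features
            h = self.feat[step + 1](mask * x)      # feature transformer on masked input
            d, a = h[:, :self.n_d], h[:, self.n_d:]  # split
            out = out + torch.relu(d)              # additive decision output
        return self.head(out)

model = TabNetSketch(n_features=10)
print(model(torch.randn(32, 10)).shape)  # torch.Size([32, 1])
```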

Attentive transformer

  • uses the previous step's result to compute the current step's mask, and does its best to keep the mask sparse and non-repetitive (see the sparsemax sketch below)
  • different data points get different masks, so feature selection is instance-wise (a tree selects the same features for the whole batch)
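The paper gets sparsity from sparsemax rather than softmax, so a mask entry can be exactly zero for unselected features. A minimal sparsemax sketch (my own implementation of Martins & Astudillo, 2016, not the paper's code):

```python
import torch

def sparsemax(z, dim=-1):
    """Sparsemax: a softmax variant that can return exact zeros."""
    z_sorted, _ = torch.sort(z, dim=dim, descending=True)
    ks = torch.arange(1, z.size(dim) + 1, device=z.device, dtype=z.dtype)
    view = [1] * z.dim()
    view[dim] = -1
    ks = ks.view(view)
    z_cumsum = z_sorted.cumsum(dim)
    support = 1 + ks * z_sorted > z_cumsum        # entries kept in the support
    k_z = support.sum(dim=dim, keepdim=True)      # support size
    tau = (z_cumsum.gather(dim, k_z - 1) - 1) / k_z.to(z.dtype)
    return torch.clamp(z - tau, min=0.0)

logits = torch.tensor([[2.0, 1.0, -1.0, -3.0]])
print(sparsemax(logits))  # tensor([[1., 0., 0., 0.]])
```

In TabNet the mask is sparsemax applied to prior-scaled logits, and the prior update prior *= (gamma - mask) penalizes reusing the same features across steps, which is what keeps masks non-repetitive.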

Feature transformer

  • processes the features selected at the current step into an embedding (the selection itself comes from the attentive transformer's mask); a sketch follows
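In the paper, a feature transformer is a stack of FC + BN + GLU blocks, some shared across steps; a hedged sketch of one such stack (layer sizes are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GLUBlock(nn.Module):
    """One FC -> BN -> GLU unit, the building block of a feature transformer."""
    def __init__(self, n_in, n_out):
        super().__init__()
        self.fc = nn.Linear(n_in, 2 * n_out)  # 2x width for the GLU gate
        self.bn = nn.BatchNorm1d(2 * n_out)

    def forward(self, x):
        h = self.bn(self.fc(x))
        return F.glu(h, dim=-1)  # first half * sigmoid(second half)

class FeatureTransformerSketch(nn.Module):
    """Two GLU blocks with a scaled residual, mirroring the paper's sqrt(0.5) scaling."""
    def __init__(self, n_in, n_hidden):
        super().__init__()
        self.block1 = GLUBlock(n_in, n_hidden)
        self.block2 = GLUBlock(n_hidden, n_hidden)

    def forward(self, x):
        h = self.block1(x)
        return (h + self.block2(h)) * (0.5 ** 0.5)  # residual, scaled by sqrt(0.5)

ft = FeatureTransformerSketch(n_in=10, n_hidden=16)
print(ft(torch.randn(32, 10)).shape)  # torch.Size([32, 16])
```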

TabNet evaluation

Paper evaluation:

TabNet performs on par with or better than GBM-based models.

Private data: offer-activation prediction

The tabular data contains customer profiles and sales records.

  LightGBM ROC AUC: 0.83
  TabNet ROC AUC: 0.80

TabNet needs roughly 100x the machine cost to match LightGBM's training speed.
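For anyone who wants to reproduce this kind of comparison, here is a hedged sketch using LightGBM's scikit-learn API and the dreamquark-ai pytorch-tabnet package, with synthetic data standing in for the private dataset (verify the TabNet API against the package docs):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from lightgbm import LGBMClassifier
from pytorch_tabnet.tab_model import TabNetClassifier

# synthetic stand-in for the private offer-activation data
X, y = make_classification(n_samples=5000, n_features=30, random_state=0)
X_tr, X_va, y_tr, y_va = train_test_split(X, y, test_size=0.2, random_state=0)

lgbm = LGBMClassifier(n_estimators=200)
lgbm.fit(X_tr, y_tr)
print("LGBM AUC:", roc_auc_score(y_va, lgbm.predict_proba(X_va)[:, 1]))

tabnet = TabNetClassifier()  # defaults; tune n_d/n_a/n_steps in practice
tabnet.fit(X_tr, y_tr, eval_set=[(X_va, y_va)], eval_metric=["auc"], max_epochs=50)
print("TabNet AUC:", roc_auc_score(y_va, tabnet.predict_proba(X_va)[:, 1]))
```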

Kaggle

Data: gene expression

https://www.kaggle.com/c/lish-moa/data?select=train_features.csv

Target: whether the sample had a positive response for each MoA target.

The top solutions all include TabNet in their ensembles, but TabNet does not carry a large weight in them.

TabNet shows good promise for applying NNs to tabular data. Although its performance is still not quite comparable to GBM models, it can serve as the online-learning component in mixed ensembles for tabular data.
