Deep learning has taken over computer vision, natural language processing, speech recognition, and many other fields, achieving astonishing results and even superhuman performance in some of them. However, deep learning has seen relatively limited use for modeling tabular data.
For tabular data, the most common approach is to use tree-based models and their ensembles. Tree-based models select features globally: at each split, they pick the feature that reduces entropy (or another impurity measure) the most over the training data. Ensemble methods improve on single trees further: bagging (as in random forests) mainly reduces model variance, while boosting mainly reduces bias. Gradient-boosted tree ensembles such as XGBoost and LightGBM have dominated Kaggle competitions.
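To make "reduce the entropy the most" concrete, here is a minimal sketch of the entropy-based criterion (information gain, in the style of ID3/C4.5 trees) used to pick one split. The toy dataset and all function names are hypothetical. Gradient-boosting libraries such as XGBoost and LightGBM score splits with a gradient-based gain rather than entropy, but the "score every candidate feature, keep the best" loop is the same idea.

```python
import math
from collections import Counter
from typing import Dict, List, Sequence


def entropy(labels: Sequence[str]) -> float:
    """Shannon entropy (in bits) of a list of class labels."""
    total = len(labels)
    probs = [count / total for count in Counter(labels).values()]
    return sum(p * math.log2(1 / p) for p in probs)


def information_gain(values: Sequence[str], labels: Sequence[str]) -> float:
    """Entropy reduction from splitting on a categorical feature (one branch per value)."""
    total = len(labels)
    groups: Dict[str, List[str]] = {}
    for value, label in zip(values, labels):
        groups.setdefault(value, []).append(label)
    # Weighted average entropy of the child nodes after the split.
    weighted = sum(len(group) / total * entropy(group) for group in groups.values())
    return entropy(labels) - weighted


# Hypothetical toy dataset: two categorical features and a binary label.
data = {
    "outlook": ["sunny", "sunny", "rain", "rain"],
    "windy": ["yes", "no", "yes", "no"],
}
labels = ["no", "no", "yes", "yes"]

gains = {name: information_gain(column, labels) for name, column in data.items()}
best = max(gains, key=gains.get)  # the feature the tree would split on
print(gains)  # {'outlook': 1.0, 'windy': 0.0}
print(best)   # outlook
```

Here `outlook` separates the labels perfectly (gain of 1 bit), while `windy` leaves both children as mixed as the parent (gain of 0), so the tree splits on `outlook`.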
TabNet is a neural architecture developed by researchers at Google Cloud AI. It achieved state-of-the-art results on several datasets, in both regression and classification problems. It combines the ability of neural networks to fit very complex functions with the **feature selection** property of tree-based algorithms. In other words, the model learns to select only the relevant features during training. Moreover, unlike tree-based models, which can only select features globally, TabNet selects features instance-wise, so each input row can rely on a different subset of features. Another desirable property of TabNet is interpretability. Unlike most deep learning models, whose neural networks act as black boxes, TabNet lets us inspect which features the model selects.
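To illustrate instance-wise feature selection and why it helps interpretability, here is a minimal pure-Python sketch. TabNet normalizes its feature masks with sparsemax, which, unlike softmax, can assign exactly zero weight to a feature. The feature names, values, and mask logits below are hypothetical and fixed by hand; in TabNet the logits come from a learned attentive transformer, and the computation runs on PyTorch tensors. The averaging in the hypothetical helper `global_importance` is a simplification of how TabNet aggregates masks across decision steps.

```python
from typing import Dict, List, Tuple


def sparsemax(z: List[float]) -> List[float]:
    """Euclidean projection of scores z onto the probability simplex (sparsemax)."""
    z_sorted = sorted(z, reverse=True)
    cumsum = 0.0
    support_size = 0
    support_sum = 0.0
    for k, value in enumerate(z_sorted, start=1):
        cumsum += value
        # The k-th largest score stays in the support while 1 + k * z_(k) > sum of top-k scores.
        if 1 + k * value > cumsum:
            support_size = k
            support_sum = cumsum
    tau = (support_sum - 1) / support_size  # threshold subtracted from every score
    return [max(v - tau, 0.0) for v in z]


def mask_rows(
    rows: List[List[float]], row_logits: List[List[float]]
) -> Tuple[List[List[float]], List[List[float]]]:
    """Compute one sparse mask per row (instance-wise) and apply it to that row's features."""
    masks = [sparsemax(logits) for logits in row_logits]
    masked = [[x * m for x, m in zip(row, mask)] for row, mask in zip(rows, masks)]
    return masks, masked


def global_importance(masks: List[List[float]], feature_names: List[str]) -> Dict[str, float]:
    """Average the per-row masks to get a dataset-level view of feature usage."""
    n = len(masks)
    return {name: sum(mask[j] for mask in masks) / n for j, name in enumerate(feature_names)}


# Hypothetical toy data: two rows, three features, and hand-picked mask logits per row.
feature_names = ["age", "income", "zip_code"]
rows = [[35.0, 52000.0, 94110.0], [62.0, 18000.0, 10001.0]]
row_logits = [[2.0, 1.0, 0.0], [0.0, 1.0, 1.5]]

masks, masked = mask_rows(rows, row_logits)
for mask in masks:
    print(mask)  # [1.0, 0.0, 0.0] then [0.0, 0.25, 0.75]
print(global_importance(masks, feature_names))  # {'age': 0.5, 'income': 0.125, 'zip_code': 0.375}
```

The first row uses only `age`, while the second splits its attention between `income` and `zip_code`; each mask sums to 1 and is exactly zero elsewhere. Reading these masks directly is what makes the selected features inspectable, whereas softmax would give every feature a non-zero weight.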
In this post, I will take you through a beginner-friendly, step-by-step implementation of TabNet in PyTorch. Let's get started!