In this article, I will give a hands-on example (with code) of how to use the popular PyTorch framework to apply the Vision Transformer, proposed in the paper “An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale” (which I reviewed in another post), to a practical computer vision task.

To do that, we will look at the problem of handwritten digit recognition using the well-known MNIST dataset.

Examples of MNIST handwritten digits, plotted with Matplotlib's Pyplot

I would like to state a caveat right away, just to make things clear. I chose the MNIST dataset for this demonstration because it is simple enough that a model can be trained on it from scratch and used for predictions within minutes (not hours or days) without any specialized hardware, so literally anyone with a computer can do it and see how it works. I made little effort to optimize the model's hyperparameters, and I certainly did not aim to achieve state-of-the-art accuracy on this dataset (currently around 99.8%) with this approach.

In fact, while I will show that the Vision Transformer can attain a respectable accuracy of over 98% on MNIST, it is arguably not the best tool for this job. Since each image in this dataset is small (just 28x28 pixels) and contains a single object, global attention, which lets every part of the image attend to every other part, can offer only limited benefit. I might write another post later examining how this model performs on a bigger dataset with larger images and a greater variety of classes. For now, I just want to show how it works.
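To make the "16x16 words" idea concrete for a 28x28 MNIST digit, here is a minimal, framework-free sketch in plain Python of the step where a Vision Transformer cuts an image into non-overlapping square patches, each of which becomes one input token. The patch size of 7 pixels and the helper names `split_into_patches` and `num_patches` are my own assumptions for illustration, not necessarily the settings used later in this article.

```python
# Hypothetical sketch: split a square grayscale image into non-overlapping
# square patches, the way a Vision Transformer turns an image into tokens.
# The patch size (7 for MNIST) is an illustrative assumption.
from typing import List

Image = List[List[float]]


def split_into_patches(image: Image, patch_size: int) -> List[List[float]]:
    """Return flattened, non-overlapping patches in row-major order."""
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("image sides must be divisible by patch_size")
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            # Flatten one patch_size x patch_size block into a single vector.
            patch = [image[r][c]
                     for r in range(top, top + patch_size)
                     for c in range(left, left + patch_size)]
            patches.append(patch)
    return patches


def num_patches(image_size: int, patch_size: int) -> int:
    """Number of tokens produced from a square image."""
    return (image_size // patch_size) ** 2


if __name__ == "__main__":
    # A dummy 28x28 "MNIST-sized" image whose pixel value encodes its position.
    mnist_like = [[float(r * 28 + c) for c in range(28)] for r in range(28)]
    patches = split_into_patches(mnist_like, patch_size=7)
    print(len(patches), len(patches[0]))  # 16 49
    print(num_patches(224, 16))           # 196
```

In an actual PyTorch model, the same rearrangement is usually done with tensor reshapes over a whole batch at once, followed by a learned linear projection of each flattened patch. Since global self-attention compares every token with every other token, 16 patches yield only 16 x 16 = 256 pairwise scores per head and layer, versus 196 x 196 = 38,416 for a 224x224 image cut into 16x16 patches, which gives a sense of how little global context there is to exploit in an MNIST digit.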

A Practical Demonstration of Using Vision Transformers in PyTorch