Year: 2025
Role: Machine Learning Engineer
Duration: 1 week
Relevant links: Github
Summary: A transformer architecture that generates captions for an input image, trained on the Flickr30k dataset.
I wanted to learn how to implement transformers from scratch and explore multimodal learning. Image captioning seemed like a great way to combine vision and language in an ambitious yet feasible one-week project. The goal was to build a system that generates accurate and meaningful captions for images, trained on the Flickr30k dataset.
Initially, I planned to train everything from scratch—and successfully did so for the smaller task of MNIST (outlined below). However, for the final model, I leveraged OpenAI’s CLIP model as a pre-trained baseline, which significantly improved performance.
This project deepened my understanding of transformers, multimodal learning, and the power of pre-trained models. While a transformer trained from scratch worked well on MNIST, CLIP embeddings were essential for handling a real-world dataset like Flickr30k. The final model generates high-quality captions by combining deep learning techniques with strategic sampling and ranking.
I started by implementing a basic transformer architecture on a simple problem: MNIST digit classification. I trained a Vision Transformer (ViT) encoder and a transformer-based decoder to process this dataset. Since MNIST is relatively easy, the model performed well, giving me confidence in my approach. You can find the project under this Github repo.
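As a rough illustration, here is a minimal ViT-style encoder of the kind I trained on MNIST. The patch size, embedding dimension, and depth are illustrative choices, not the exact hyperparameters from the project.

```python
# Minimal ViT-style encoder for MNIST (illustrative hyperparameters).
import torch
import torch.nn as nn

class ViTEncoder(nn.Module):
    def __init__(self, image_size=28, patch_size=7, dim=128, depth=4, heads=4):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2            # 16 patches for MNIST
        self.patch_embed = nn.Conv2d(1, dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))    # learnable [CLS] token
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads,
                                           dim_feedforward=dim * 4, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)

    def forward(self, x):                                         # x: (B, 1, 28, 28)
        patches = self.patch_embed(x).flatten(2).transpose(1, 2)  # (B, 16, dim)
        cls = self.cls_token.expand(x.size(0), -1, -1)
        tokens = torch.cat([cls, patches], dim=1) + self.pos_embed
        return self.encoder(tokens)                               # (B, 17, dim)

# For digit classification, a linear head on the [CLS] output is enough:
# logits = nn.Linear(128, 10)(ViTEncoder()(images)[:, 0])
```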
Applying the same approach to my image captioning model, I quickly realized that training from scratch with only 30k images wasn’t sufficient. I needed a pre-trained model.
To improve performance, I leveraged OpenAI’s CLIP model, which is trained on a vast dataset of images and their textual descriptions. Thanks to its contrastive learning approach, CLIP generalizes much better than the ViT model I had previously trained.
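For a concrete picture, here is a sketch of extracting CLIP image embeddings to feed the decoder, using the Hugging Face wrapper around OpenAI’s CLIP; the checkpoint name is just an example.

```python
# Extract a pooled CLIP image embedding for one image (sketch).
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

image = Image.open("example.jpg")
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    image_embeds = model.get_image_features(**inputs)  # (1, 512) projected embedding
```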
Using CLIP’s embeddings as input, I trained my decoder to generate captions. Inspired by the Pixtral 12B paper, I simplified the architecture by using the CLS token from the image embedding as the start token, rather than employing cross-attention. This streamlined the model while maintaining performance.
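Below is a minimal sketch of that decoder design, assuming the projected CLIP embedding is prepended to the caption tokens and stands in for the start token, so no cross-attention is needed; dimensions and vocabulary size are placeholders.

```python
# Decoder-only captioning transformer conditioned on a CLIP embedding (sketch).
import torch
import torch.nn as nn

class CaptionDecoder(nn.Module):
    def __init__(self, vocab_size=10000, clip_dim=512, dim=256, depth=4, heads=4, max_len=64):
        super().__init__()
        self.img_proj = nn.Linear(clip_dim, dim)        # map CLIP embedding into decoder space
        self.tok_embed = nn.Embedding(vocab_size, dim)
        self.pos_embed = nn.Parameter(torch.zeros(1, max_len + 1, dim))
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads,
                                           dim_feedforward=dim * 4, batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, num_layers=depth)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, image_embeds, caption_ids):
        # image_embeds: (B, clip_dim), caption_ids: (B, T)
        img_tok = self.img_proj(image_embeds).unsqueeze(1)   # (B, 1, dim) acts as start token
        txt_tok = self.tok_embed(caption_ids)                # (B, T, dim)
        seq = torch.cat([img_tok, txt_tok], dim=1) + self.pos_embed[:, : caption_ids.size(1) + 1]
        mask = nn.Transformer.generate_square_subsequent_mask(seq.size(1))  # causal mask
        hidden = self.blocks(seq, mask=mask)
        return self.head(hidden)                              # next-token logits
```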
For inference, I generated multiple caption candidates by sampling at different temperature values. To select the best caption, I used CLIP itself to rank the options based on their relevance to the image. This final step significantly improved caption quality.
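The sample-then-rank step can look roughly like the sketch below; `generate_caption` is a hypothetical helper wrapping the decoder’s autoregressive sampling loop, and the temperature grid is illustrative.

```python
# Generate candidates at several temperatures, then rank them with CLIP (sketch).
import torch

def best_caption(image, clip_model, processor, generate_caption,
                 temperatures=(0.5, 0.7, 1.0, 1.2), n_per_temp=3):
    candidates = [generate_caption(image, temperature=t)
                  for t in temperatures for _ in range(n_per_temp)]
    inputs = processor(text=candidates, images=image,
                       return_tensors="pt", padding=True)
    with torch.no_grad():
        out = clip_model(**inputs)
    scores = out.logits_per_image.squeeze(0)   # image-text similarity per candidate
    return candidates[scores.argmax().item()]
```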