AnyModal: Train Multimodal LLMs in PyTorch
Ritabrata Maiti
Posted on November 19, 2024
Today, I want to introduce an open-source framework I’ve been working on: AnyModal.
Introduction
During my work on machine learning projects, I struggled to find flexible solutions for training multimodal large language models (LLMs). While there are plenty of great tools for specific tasks, like image classification or audio processing, there was no straightforward way to combine these modalities with LLMs. The process was often tedious, involving boilerplate code, custom integration, and a lot of trial and error to make different components work together.
This frustration led me to build AnyModal, a framework designed to reduce the complexity of multimodal AI development. It provides a modular, reusable structure that makes it easier for developers and researchers to combine diverse data types and experiment with new ideas without reinventing the wheel every time.
The Goal
AnyModal is built with the following objectives in mind:
Reduce Boilerplate Code
Combining modalities like images or audio with LLMs typically involves the same repetitive steps: preprocessing, encoding, tokenizing, and integrating. AnyModal minimizes this boilerplate by providing reusable modules for these common tasks, so developers can spend their time on the model itself rather than on glue code.
Enable Seamless Integration
Whether you're working with images using a Vision Transformer (ViT) or audio spectrograms, AnyModal offers plug-and-play components that simplify the integration process. This makes it easy to handle multiple data types within a single framework.
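One way to picture the plug-and-play idea: if every encoder returns a `(batch, sequence, features)` tensor, the downstream projector and LLM do not need to know which modality produced it. The minimal sketch below assumes that contract. `ToyImageEncoder` and `ToySpectrogramEncoder` are hypothetical stand-ins written for illustration, not AnyModal classes, and the 80 mel bins and 100 frames are arbitrary example sizes.

```python
import torch
from torch import nn


class ToyImageEncoder(nn.Module):
    """Hypothetical: splits an image into non-overlapping patches and embeds each one."""

    def __init__(self, patch_size: int = 16, dim: int = 64):
        super().__init__()
        # A strided convolution patchifies and embeds in a single step.
        self.patchify = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        x = self.patchify(images)            # (B, dim, H/p, W/p)
        return x.flatten(2).transpose(1, 2)  # (B, num_patches, dim)


class ToySpectrogramEncoder(nn.Module):
    """Hypothetical: treats each spectrogram time frame as one token."""

    def __init__(self, n_mels: int = 80, dim: int = 64):
        super().__init__()
        self.frame_proj = nn.Linear(n_mels, dim)

    def forward(self, spectrograms: torch.Tensor) -> torch.Tensor:
        # (B, n_mels, T) -> (B, T, n_mels) -> (B, T, dim)
        return self.frame_proj(spectrograms.transpose(1, 2))


# Both encoders can feed the same projector because they share an output contract.
projector = nn.Linear(64, 768)
image_tokens = projector(ToyImageEncoder()(torch.randn(2, 3, 224, 224)))
audio_tokens = projector(ToySpectrogramEncoder()(torch.randn(2, 80, 100)))
print(image_tokens.shape, audio_tokens.shape)  # torch.Size([2, 196, 768]) torch.Size([2, 100, 768])
```

Swapping modalities then only means swapping the encoder; the projector and everything after it stay the same.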
Encourage Experimentation and Customization
AnyModal supports rapid prototyping while offering the flexibility to customize components like feature encoders, projection layers, and tokenizers. The aim is for it to be versatile enough for both quick experiments and production-level deployments.
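As an example of customizing the projection layer, the sketch below replaces a single linear map with a two-layer MLP with a GELU in between, the kind of projector LLaVA-1.5 uses. `MLPProjector` is a hypothetical class; whether it can be passed to `MultiModalModel` in place of AnyModal's own `Projector` without changes is an assumption to check against the repository.

```python
from typing import Optional

import torch
from torch import nn


class MLPProjector(nn.Module):
    """Hypothetical drop-in alternative to a single linear projection."""

    def __init__(self, in_features: int, out_features: int, hidden_features: Optional[int] = None):
        super().__init__()
        hidden_features = hidden_features or out_features
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.GELU(),
            nn.Linear(hidden_features, out_features),
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # Applied independently to every token: (B, S, in) -> (B, S, out)
        return self.net(features)


linear = nn.Linear(768, 768)
mlp = MLPProjector(768, 768)
x = torch.randn(1, 197, 768)
print(mlp(x).shape)  # torch.Size([1, 197, 768])

# Parameter count: one linear layer vs. the two-layer MLP
print(sum(p.numel() for p in linear.parameters()))  # 590592
print(sum(p.numel() for p in mlp.parameters()))     # 1181184
```

At these sizes the MLP doubles the projector's parameter count; whether that extra capacity pays off is an empirical question for each task.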
Example Usage: Integrating Images with LLMs
Here’s a detailed example of how AnyModal simplifies the integration of image data into LLMs:
1. Install Dependencies
```bash
pip install torch transformers datasets torchvision tqdm
```
2. Initialize Vision Components
```python
from transformers import ViTImageProcessor, ViTForImageClassification
from vision import VisionEncoder  # vision module from the AnyModal repository

# Load a pre-trained Vision Transformer and its image processor
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
vision_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

# Wrap the ViT in a Vision Encoder that extracts feature embeddings
vision_encoder = VisionEncoder(vision_model)
```
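The vision encoder's job is to turn an image into a sequence of feature vectors rather than class logits. For `google/vit-base-patch16-224`, a 224×224 image split into 16×16 patches gives 14 × 14 = 196 patch tokens plus one `[CLS]` token, so the last hidden state has shape `(batch, 197, 768)`. The sketch below is a hypothetical, randomly initialized stand-in that reproduces that shape in plain PyTorch; AnyModal's actual `VisionEncoder` wraps the pretrained model, and the only assumption here is that it returns per-token hidden states like these.

```python
import torch
from torch import nn


class ToyViTFeatures(nn.Module):
    """Hypothetical stand-in for a ViT backbone: patch embedding + [CLS] + one encoder layer."""

    def __init__(self, image_size=224, patch_size=16, hidden_size=768, num_heads=12):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2  # 14 * 14 = 196
        self.patch_embed = nn.Conv2d(3, hidden_size, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, hidden_size))
        self.encoder = nn.TransformerEncoderLayer(hidden_size, num_heads, batch_first=True)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        patches = self.patch_embed(pixel_values).flatten(2).transpose(1, 2)  # (B, 196, D)
        cls = self.cls_token.expand(patches.size(0), -1, -1)                 # (B, 1, D)
        tokens = torch.cat([cls, patches], dim=1) + self.pos_embed           # (B, 197, D)
        return self.encoder(tokens)  # last hidden state: one vector per token


features = ToyViTFeatures()(torch.randn(2, 3, 224, 224))
print(features.shape)  # torch.Size([2, 197, 768])
```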
3. Initialize Tokenizer and LLM
```python
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load a pre-trained LLM and its tokenizer
llm_tokenizer = AutoTokenizer.from_pretrained("gpt2")
llm_model = AutoModelForCausalLM.from_pretrained("gpt2")
```
4. Define a Projection Layer
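```python
from vision import Projector

# Create a projection layer to map vision embeddings into the LLM's embedding space.
# 768 is the embedding width of GPT-2 (small); llm_model.config.hidden_size gives
# the same value without hard-coding it.
vision_tokenizer = Projector(
    in_features=vision_model.config.hidden_size,
    out_features=768
)
```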
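The projector exists because the vision backbone's hidden size need not match the LLM's embedding width. With `vit-base-patch16-224` and GPT-2 both widths happen to be 768, so the layer mainly learns an alignment between the two spaces; with a wider backbone it also changes the dimension. The sketch below assumes the projector acts like a per-token `nn.Linear` (an assumption, not a reading of AnyModal's source) and uses an illustrative 1024-wide backbone and a toy vocabulary.

```python
import torch
from torch import nn

vision_hidden = 1024  # illustrative: a backbone wider than the LLM
llm_hidden = 768      # embedding width of GPT-2 (small)

# Assumption: the projector behaves like a per-token linear map.
projector = nn.Linear(vision_hidden, llm_hidden)
text_embedding = nn.Embedding(1000, llm_hidden)  # toy vocabulary, GPT-2 width

vision_features = torch.randn(1, 197, vision_hidden)  # (batch, tokens, vision width)
prompt_ids = torch.tensor([[11, 12, 13, 14]])         # placeholder token IDs

soft_tokens = projector(vision_features)  # (1, 197, 768): now in the LLM's width
text_tokens = text_embedding(prompt_ids)  # (1, 4, 768)

# Concatenating along the sequence axis only works because the last dims match;
# torch.cat([text_tokens, vision_features], dim=1) would raise a RuntimeError.
combined = torch.cat([text_tokens, soft_tokens], dim=1)
print(combined.shape)  # torch.Size([1, 201, 768])
```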
5. Combine Everything with AnyModal
```python
from anymodal import MultiModalModel

# Build the multimodal model
multimodal_model = MultiModalModel(
    input_processor=None,
    input_encoder=vision_encoder,
    input_tokenizer=vision_tokenizer,
    language_tokenizer=llm_tokenizer,
    language_model=llm_model,
    input_start_token='<|imstart|>',  # marks where the image tokens begin
    input_end_token='<|imend|>',      # marks where the image tokens end
    prompt_text="Describe this image: "
)
```
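Conceptually, a wrapper like `MultiModalModel` has to splice the projected image features into the LLM's input embeddings between the start and end markers, and make sure the loss is computed only on the caption tokens. The sketch below shows one common way to do this under stated assumptions: the ordering (prompt, `<|imstart|>`, image tokens, `<|imend|>`, caption), the toy sizes, the marker IDs, and the helper `build_inputs` are all hypothetical, not AnyModal's actual implementation. Positions that should not contribute to the loss get the label `-100`, the default `ignore_index` of PyTorch's `cross_entropy`.

```python
import torch
import torch.nn.functional as F
from torch import nn

hidden, vocab_size = 32, 100
IMSTART_ID, IMEND_ID = 98, 99  # hypothetical IDs for the two marker tokens
IGNORE = -100                  # default ignore_index of F.cross_entropy

embed = nn.Embedding(vocab_size, hidden)
lm_head = nn.Linear(hidden, vocab_size)  # stands in for the rest of the LLM


def build_inputs(prompt_ids, image_tokens, caption_ids):
    """Hypothetical helper: returns (inputs_embeds, labels) for one example."""
    start = embed(torch.tensor([IMSTART_ID]))
    end = embed(torch.tensor([IMEND_ID]))
    embeds = torch.cat([embed(prompt_ids), start, image_tokens, end, embed(caption_ids)], dim=0)
    # Everything before the caption is context: mask it out of the loss.
    n_context = len(prompt_ids) + 1 + image_tokens.size(0) + 1
    labels = torch.cat([torch.full((n_context,), IGNORE, dtype=torch.long), caption_ids])
    return embeds, labels


prompt_ids = torch.tensor([5, 6, 7])
caption_ids = torch.tensor([10, 11, 12, 13])
image_tokens = torch.randn(4, hidden)  # already projected to the LLM width

embeds, labels = build_inputs(prompt_ids, image_tokens, caption_ids)
logits = lm_head(embeds)  # (seq_len, vocab_size)

# Causal LM shift: the logits at position t predict the token at position t + 1.
loss = F.cross_entropy(logits[:-1], labels[1:], ignore_index=IGNORE)
print(embeds.shape, int((labels != IGNORE).sum()))  # torch.Size([13, 32]) 4
```

Hugging Face causal LM classes such as `GPT2LMHeadModel` accept `inputs_embeds` and perform this shift internally when `labels` is passed, so a wrapper built on them would typically hand over unshifted labels.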
6. Training and Inference
Training involves processing batches of image-text pairs and optimizing the model:
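```python
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset

# Load a sample dataset (replace "image_caption_dataset" with the
# image-caption dataset you want to train on)
dataset = load_dataset("image_caption_dataset", split="train")

# Prepare the DataLoader. Batches of PIL images may need a custom collate_fn,
# since PyTorch's default collate function cannot stack them.
train_loader = DataLoader(dataset, batch_size=2, shuffle=True)

# Training loop
optimizer = torch.optim.AdamW(multimodal_model.parameters(), lr=3e-4)
multimodal_model.train()
for epoch in range(10):
    for batch in train_loader:
        optimizer.zero_grad()
        logits, loss = multimodal_model(batch)  # forward pass returns logits and loss
        loss.backward()
        optimizer.step()

# Generate captions
multimodal_model.eval()
sample_input = dataset[0]['image']
with torch.no_grad():
    generated_caption = multimodal_model.generate(sample_input, max_new_tokens=30)
print("Generated Caption:", generated_caption)
```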
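At inference time, the prompt and image tokens form a fixed prefix and the LLM extends it one token at a time until it emits an end-of-sequence token or reaches `max_new_tokens`. The sketch below is a hypothetical greedy loop over a toy model; it assumes greedy decoding for clarity, whereas AnyModal's `generate` may use other decoding settings. `EOS_ID`, the toy sizes, and `greedy_generate` are illustrative names only.

```python
import torch
from torch import nn

torch.manual_seed(0)
hidden, vocab_size, EOS_ID = 16, 50, 0  # toy sizes; EOS_ID is hypothetical

embed = nn.Embedding(vocab_size, hidden)
lm_head = nn.Linear(hidden, vocab_size)  # stand-in for a causal LM's final layer


@torch.no_grad()
def greedy_generate(prefix_embeds: torch.Tensor, max_new_tokens: int = 30) -> list[int]:
    """Hypothetical greedy decoder: appends the argmax token until EOS or the length limit."""
    embeds = prefix_embeds
    generated = []
    for _ in range(max_new_tokens):
        # A real LLM attends over all of `embeds`; this toy looks only at the last position.
        next_id = int(lm_head(embeds[-1]).argmax())
        if next_id == EOS_ID:
            break
        generated.append(next_id)
        embeds = torch.cat([embeds, embed(torch.tensor([next_id]))], dim=0)
    return generated


prefix = torch.randn(10, hidden)  # e.g. prompt + <|imstart|> + image tokens + <|imend|>
caption_ids = greedy_generate(prefix, max_new_tokens=30)
print(len(caption_ids) <= 30, EOS_ID not in caption_ids)  # True True
```

The returned IDs would then be decoded back to text with the LLM's tokenizer.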
Current Status
AnyModal is currently in its early stages, with the latest version supporting tasks like:
- LaTeX OCR
- Chest X-Ray Captioning (in progress)
- Image Captioning
Future planned features include support for visual question answering and audio captioning.
As the framework evolves, I’m focusing on expanding its functionality, refining the codebase, and addressing community feedback to move towards a stable release.
Links
- GitHub: https://github.com/ritabratamaiti/AnyModal
- Reddit: https://www.reddit.com/r/AnyModal/
- Hugging Face: https://huggingface.co/AnyModal
If you’re looking for a way to simplify multimodal AI development, give AnyModal a try. I’d love to hear your feedback or ideas for new features. Contributions are always welcome!