Text-only LLMs miss half the world: images, diagrams, charts, photos. Multimodal models both see and read, processing the two modalities together.
The multimodal challenge
Different modalities have different representations:
- Text: discrete tokens
- Images: continuous pixels
- Audio: waveforms
How do you combine them in one model?
Architecture patterns
Vision encoder + LLM
Most common approach:
- Process image with vision encoder (CLIP, ViT)
- Project to LLM embedding space
- Feed image embeddings + text tokens to LLM
Image → Vision Encoder → Projection → [IMG1][IMG2]...[IMGn]
↓
LLM receives: [IMG tokens] + [text tokens]
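LLaVA is an open example of this design. GPT-4V and Gemini are often described as variants of it, though their architectures are not fully public.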
Cross-attention
Add cross-attention layers. Text attends to image features.
Flamingo approach:
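# In each decoder layer (residual connections and layer norms omitted)
x = self_attention(x)                    # text attends to earlier text
x = cross_attention(x, image_features)   # queries from text, keys/values from image
x = ffn(x)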
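The pseudocode above can be fleshed out into a small PyTorch layer. This is a minimal sketch, not Flamingo's actual implementation: it keeps the ordering of the pseudocode and borrows only the idea of a tanh gate initialised at zero, so the layer starts out ignoring the image. The class name `GatedCrossAttentionLayer` and the toy dimensions are assumptions for illustration, and causal masking is left out for brevity.

```python
import torch
import torch.nn as nn


class GatedCrossAttentionLayer(nn.Module):
    """Decoder layer: self-attention, gated cross-attention to image features, FFN."""

    def __init__(self, dim: int = 64, num_heads: int = 4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        # Gate starts at 0, so tanh(gate) = 0 and the image contributes nothing at init.
        self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor, image_features: torch.Tensor) -> torch.Tensor:
        h = self.norm1(x)
        x = x + self.self_attn(h, h, h, need_weights=False)[0]
        h = self.norm2(x)
        # Queries come from the text, keys and values from the image features.
        attended = self.cross_attn(h, image_features, image_features, need_weights=False)[0]
        x = x + torch.tanh(self.gate) * attended
        return x + self.ffn(self.norm3(x))


layer = GatedCrossAttentionLayer()
text = torch.randn(2, 10, 64)            # (batch, text_len, dim)
image_features = torch.randn(2, 49, 64)  # (batch, num_patches, dim)
out = layer(text, image_features)
print(out.shape)  # torch.Size([2, 10, 64])
```

The output keeps the text sequence length: image features are only read from, never appended to the sequence, which is the main difference from the projection approach.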
Early fusion
Treat image patches as tokens from the start.
# ViT style
image_patches = patchify(image)                # (batch, N, patch_dim)
image_tokens = patch_embed(image_patches)      # (batch, N, dim)
# Concatenate with text tokens along the sequence axis
all_tokens = torch.cat([image_tokens, text_tokens], dim=1)  # (batch, N + T, dim)
output = transformer(all_tokens)
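To make the snippet above concrete, here is a runnable toy version. It is a sketch under assumptions: `patchify` is written out by hand, the sizes (`dim`, `patch_size`, `vocab_size`) are made up, and position and modality embeddings, which a real model would need, are omitted.

```python
import torch
import torch.nn as nn


def patchify(images: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Split (B, C, H, W) images into flattened patches of shape (B, N, C * P * P)."""
    b, c, h, w = images.shape
    p = patch_size
    assert h % p == 0 and w % p == 0, "image size must be divisible by patch size"
    x = images.reshape(b, c, h // p, p, w // p, p)
    x = x.permute(0, 2, 4, 1, 3, 5)  # (B, H/P, W/P, C, P, P)
    return x.reshape(b, (h // p) * (w // p), c * p * p)


dim, patch_size, vocab_size = 64, 8, 1000  # hypothetical toy sizes
patch_embed = nn.Linear(3 * patch_size * patch_size, dim)
token_embed = nn.Embedding(vocab_size, dim)
encoder_layer = nn.TransformerEncoderLayer(dim, nhead=4, batch_first=True)
transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)

images = torch.randn(2, 3, 32, 32)                # (B, C, H, W)
input_ids = torch.randint(0, vocab_size, (2, 5))  # (B, T)

image_tokens = patch_embed(patchify(images, patch_size))    # (2, 16, 64)
text_tokens = token_embed(input_ids)                        # (2, 5, 64)
all_tokens = torch.cat([image_tokens, text_tokens], dim=1)  # (2, 21, 64)
output = transformer(all_tokens)
print(output.shape)  # torch.Size([2, 21, 64])
```

A 32×32 image with 8×8 patches gives 16 image tokens, so the transformer sees 16 + 5 = 21 positions and every layer can mix image and text freely.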
CLIP: Connecting vision and language
Contrastive learning on image-text pairs.
Two encoders:
- Image encoder: image → embedding
- Text encoder: text → embedding
Training: matching pairs should have similar embeddings, and mismatched pairs within the batch dissimilar ones.
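# Simplified CLIP loss
import torch
import torch.nn.functional as F
# L2-normalize so the dot product below is a true cosine similarity
image_emb = F.normalize(image_encoder(images), dim=-1)  # (batch, dim)
text_emb = F.normalize(text_encoder(texts), dim=-1)     # (batch, dim)
# Cosine similarity matrix, scaled by temperature
logits = image_emb @ text_emb.T / temperature           # (batch, batch)
# Cross-entropy loss (each image matches one text and vice versa)
labels = torch.arange(logits.size(0), device=logits.device)
loss_i = F.cross_entropy(logits, labels)    # image → text
loss_t = F.cross_entropy(logits.T, labels)  # text → image
loss = (loss_i + loss_t) / 2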
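The loss above can be wrapped in a function and sanity-checked on random tensors, with no encoders involved. This is a sketch under assumptions: the function name `clip_loss` is made up, embeddings are L2-normalized so that the dot product is a cosine similarity, and the temperature is fixed at 0.07, whereas CLIP learns it during training.

```python
import torch
import torch.nn.functional as F


def clip_loss(image_emb: torch.Tensor, text_emb: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    """Symmetric contrastive loss; row i of each input is assumed to be a matching pair."""
    image_emb = F.normalize(image_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)
    logits = image_emb @ text_emb.T / temperature  # (batch, batch)
    labels = torch.arange(logits.size(0), device=logits.device)
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2


torch.manual_seed(0)
text_emb = torch.randn(8, 32)
aligned = text_emb + 0.1 * torch.randn(8, 32)  # stand-in "image" embeddings near their texts
shuffled = aligned.roll(shifts=1, dims=0)      # same embeddings, paired with the wrong texts

loss_aligned = clip_loss(aligned, text_emb)
loss_shuffled = clip_loss(shuffled, text_emb)
print(loss_aligned.item(), loss_shuffled.item())
```

Correctly paired embeddings should give a much lower loss than the same embeddings paired with the wrong texts, which is exactly the signal that pulls matching pairs together during training.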
Building a simple multimodal model
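import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, CLIPVisionModel

class SimpleMMLLM(nn.Module):
    def __init__(self, vision_name="openai/clip-vit-base", llm_name="llama-7b"):
        # Default names are shorthand from the text; pass full checkpoint IDs in practice
        super().__init__()
        self.vision_encoder = CLIPVisionModel.from_pretrained(vision_name)
        self.llm = AutoModelForCausalLM.from_pretrained(llm_name)
        # Read hidden sizes from the configs rather than hard-coding them
        vision_hidden = self.vision_encoder.config.hidden_size  # 768 for CLIP ViT-B
        llm_hidden = self.llm.config.hidden_size
        self.vision_proj = nn.Linear(vision_hidden, llm_hidden)

    def forward(self, image, input_ids):
        # Encode image: `image` is a preprocessed pixel tensor (batch, 3, H, W)
        vision_outputs = self.vision_encoder(pixel_values=image)
        image_features = vision_outputs.last_hidden_state  # (batch, n_img, vision_hidden)
        image_embeds = self.vision_proj(image_features)    # (batch, n_img, llm_hidden)
        # Get text embeddings
        text_embeds = self.llm.get_input_embeddings()(input_ids)
        # Concatenate image tokens in front of the text tokens
        inputs_embeds = torch.cat([image_embeds, text_embeds], dim=1)
        # Forward through LLM
        return self.llm(inputs_embeds=inputs_embeds)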
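One practical detail when training a model like `SimpleMMLLM`: the sequence fed to the LLM is longer than `input_ids`, so any labels must be padded to match. A common choice is to fill the image positions with -100, the default `ignore_index` of PyTorch's cross-entropy, so that no loss is computed there. The helper name `build_labels` below is hypothetical, and the sketch assumes the image tokens come first, as in `forward` above.

```python
import torch

IGNORE_INDEX = -100  # default ignore_index of torch.nn.functional.cross_entropy


def build_labels(input_ids: torch.Tensor, num_image_tokens: int) -> torch.Tensor:
    """Prepend IGNORE_INDEX for each image position so labels align with inputs_embeds."""
    batch = input_ids.size(0)
    image_labels = torch.full(
        (batch, num_image_tokens),
        IGNORE_INDEX,
        dtype=input_ids.dtype,
        device=input_ids.device,
    )
    return torch.cat([image_labels, input_ids], dim=1)


input_ids = torch.tensor([[5, 6, 7]])
labels = build_labels(input_ids, num_image_tokens=2)
print(labels.tolist())  # [[-100, -100, 5, 6, 7]]
```

The same reasoning applies to the attention mask: it needs one entry per image token as well as per text token.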
Capabilities
With image + language:
- Visual question answering
- Image captioning
- Document understanding
- Chart/graph interpretation
- Visual reasoning
- Image-guided code generation
Challenges
- Resolution: higher res = more tokens = expensive
- Hallucination: models make up image contents
- Spatial reasoning: often weak
- Fine-grained understanding: small details missed
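The resolution cost is easy to put numbers on. The sketch below counts patch tokens for a square image and the number of token pairs that self-attention has to score. The patch size of 14 is an assumption, and the function name `num_image_tokens` is made up for illustration.

```python
def num_image_tokens(height: int, width: int, patch_size: int) -> int:
    """Number of non-overlapping patches (tokens) a ViT-style encoder produces."""
    return (height // patch_size) * (width // patch_size)


for side in (224, 336, 672):
    n = num_image_tokens(side, side, patch_size=14)
    # Self-attention scores every pair of tokens, so cost grows with n * n.
    print(f"{side}x{side}: {n} tokens, {n * n} attention pairs")
# 224x224: 256 tokens, 65536 attention pairs
# 336x336: 576 tokens, 331776 attention pairs
# 672x672: 2304 tokens, 5308416 attention pairs
```

Doubling the side length quadruples the token count and multiplies the attention pairs by sixteen, which is why high-resolution inputs get expensive so quickly.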