TP 3: Multimodal Learning and Its Applications

third TP

By Mariem ZAOUALI


We will explore multimodal models, focusing on how to fuse information from different data types, including text, images, and audio. We will dive into advanced topics such as joint optimization of modalities with Contrastive Language–Image Pre-training (CLIP) and text-guided diffusion models.

Learning Objectives:

  • Understand how to encode different modalities into compact representations that capture both local and global information.
  • Learn about joint optimization in late-fusion models like CLIP, which handle text and images independently and project them into a shared embedding space.
  • Explore the text-guided diffusion model formulation involving a non-autoregressive decoder that can generate complex outputs while still allowing for multimodal conditioning.

Part 1: Defining a Modality

Before diving into multimodal architectures, let’s first clarify what we mean by a modality.

A modality is a particular form or type of data characterized by its structure and the way it conveys information. Common modalities include:

  • Text: Sequences of words or tokens.
  • Images: Two-dimensional arrays of pixel values.
  • Audio: Time-series data representing sound waves.
  • Video: Sequences of images over time.
  • Sensor Data: Measurements from devices like accelerometers or temperature sensors.

In your previous studies, you encountered convolution as a sparse reasoning mechanism that excels at incorporating spatial locality into model predictions. Convolution shines on image data: it transforms raw pixels into meaningful features, which can then feed a dense network acting as a classifier, as sketched below.
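As a quick refresher, here is a minimal PyTorch sketch of that pattern (the layer sizes and class count are arbitrary and purely illustrative):

import torch
import torch.nn as nn

# Minimal sketch: a convolutional feature extractor feeding a dense classifier head.
feature_extractor = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
)
classifier = nn.Sequential(
    nn.Flatten(),
    nn.Linear(32 * 56 * 56, 10),  # assumes 224x224 RGB inputs and 10 output classes
)

x = torch.randn(1, 3, 224, 224)            # a dummy image batch
logits = classifier(feature_extractor(x))  # pixels -> features -> class scores
print(logits.shape)                        # torch.Size([1, 10])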

Similarly, every data type (or modality) has inherent relationships that govern how its pieces of information interact, and different architectures are needed to handle them effectively. For instance, we’ve been focusing on transformers, with their token-wise dense layers and efficient attention mechanisms, as excellent reasoners for language.

This is because language is fundamentally an ordered sequence, and attention is particularly well-suited for reasoning over sequences. But it’s important to remember that the transformer’s attention mechanism is not limited to just processing words and sentences — it’s versatile enough to be applied to a wide range of modalities and can be combined with other structure-reasoning modules as necessary.

Image source: NVIDIA Keynote at SIGGRAPH 2023

In this lab, we will be primarily dealing with the following modalities:

  • Natural Language: Structured as an ordered sequence of tokens, each carrying semantic meaning within a context.
  • Images: Comprised of pixel values that can capture both micro (fine-detail) and macro (big-picture) aspects in two dimensions.

Though each modality comes with its unique challenges and architectures, they share a common goal when used in multimodal systems: to communicate information and help convert it into usable representations (whether explicit or implicit).

Part 2: Encoding Different Modalities

Transformers excel at capturing relationships between tokens in a sequence, making them highly effective for a variety of tasks such as text classification, generation, and translation. This success with text naturally leads us to explore whether transformers can handle other data types, or modalities, that have their own structures and patterns.

Embedding Text

Below is how to encode several strings of text which will be used later as experimental samples:

from transformers import BertTokenizer, BertModel
import torch

text_captions = ("Cat with paint", "rock statue", "frogs on leaf", "jellyfish")
text_dialogue = ("Cats don't usually", 'like water', 'but this one likes paint', 'quite a lot')

# Load model and tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
# Tokenize all eight strings together with padding; return_tensors uses the model's framework ("pt")
inputs = tokenizer(text_captions + text_dialogue, padding=True, return_tensors=model.framework)
num_values = inputs.get("attention_mask").sum(axis=1)  # number of real (non-pad) tokens per sequence

# Get text embeddings
with torch.no_grad():
    outputs = model(**inputs)
text_embeddings = outputs.last_hidden_state  # (batch_size, sequence_length, hidden_size)
print(f"{text_embeddings.shape = }")

captions_embeddings = text_embeddings[:4]
dialogue_embeddings = text_embeddings[4:]

We now have eight sequences of embeddings, each corresponding to a text string in our inputs. Each token in the sequence has a 768-dimensional embedding capturing some hopefully-useful contextual and semantic information.
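If you want to sanity-check what the tokenizer actually produced, you can decode the IDs of the first caption back into tokens and compare them against its attention mask (an optional check using the objects defined above):

# Optional sanity check: padded positions appear as [PAD] and are zeroed in attention_mask
print(tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist()))
print(inputs["attention_mask"][0])
print(num_values)  # number of real (non-pad) tokens per sequence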

Embedding Images

Now let’s turn to images, which are fundamentally different from text and audio. Images are 2D arrays of pixel values, so we need a way to convert these into a sequence that a transformer can process. This is where the Vision Transformer (ViT) formulation comes into play.

The ViT model treats an image as a sequence of patches. Each patch is flattened and mapped through a learned projection into a vector representation, and these vectors are treated like tokens in a text sequence. The transformer then learns to capture relationships between different patches, allowing it to understand both local patterns (within patches) and global context (across the whole image).

Image source: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale (2020)
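Under the hood, this patch embedding is commonly implemented as a strided convolution whose kernel size equals the patch size. The sketch below is illustrative only (it is not the exact Hugging Face ViT implementation):

import torch
import torch.nn as nn

# A 16x16 convolution with stride 16 maps each non-overlapping 16x16 patch
# of a 224x224 image to a 768-dimensional vector
patch_embed = nn.Conv2d(in_channels=3, out_channels=768, kernel_size=16, stride=16)

dummy_image = torch.randn(1, 3, 224, 224)
patch_grid = patch_embed(dummy_image)                  # (1, 768, 14, 14)
patch_tokens = patch_grid.flatten(2).transpose(1, 2)   # (1, 196, 768): a "sequence" of patch tokens
print(patch_tokens.shape)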

Let’s see how we might encode an image using a ViT-enabled encoder model like google/vit-base-patch16-224:

from transformers import ViTImageProcessor, ViTModel
from PIL import Image

# Load the pre-trained feature extractor and model
feature_extractor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
model = ViTModel.from_pretrained('google/vit-base-patch16-224')

# Load an example image
img_files = ["paint-cat", "rock-head", "tree-frog", "two-jelly"]
images = [Image.open(f"img-files/{name}.jpg") for name in img_files]    

# Preprocess the image to fit model input
inputs = feature_extractor(images=images, return_tensors=model.framework)

# Forward pass through the model to get image features
with torch.no_grad():
    outputs = model(**inputs)
image_hidden_states = outputs.last_hidden_state
print(image_hidden_states.shape)  # Shape: (batch_size, num_patches, hidden_size)

In this case, the ViT component divides the image into patches, processes them through the transformer, and outputs a sequence of encoded patches. Each patch now contains contextual information about its neighboring patches, allowing the model to understand the image in a structured way.
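Note that the first position of image_hidden_states holds ViT's [CLS] token, and the Hugging Face ViTModel additionally exposes a pooled output derived from it. A quick peek, reusing the outputs object from above:

# Position 0 of last_hidden_state is ViT's [CLS] token; positions 1..196 are patch tokens
cls_embeddings = image_hidden_states[:, 0, :]
print(cls_embeddings.shape)          # (4, 768)
print(outputs.pooler_output.shape)   # (4, 768): [CLS] passed through a small pooling head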

Part 3: Joint Projections

We now know that using transformers to obtain semantically dense embeddings of text and images is sensible. Regardless of the underlying structure of the data, we were able to extract representations with a 768-dimensional embedding for each token or patch. Even so, we have some discrepancies to resolve:

  • The language embedding is still a series of per-token embeddings.
  • The image embedding is a per-patch embedding.

Since we have an attention interface in each of these embedders, we know that each token carries global context for its sequence. In theory, we should favor the [CLS] token when one is provided, since it is explicitly trained to summarize the sequence; but because these encoders were not jointly optimized, their [CLS] embeddings tend to give more erratic similarity results than a simple average of the token embeddings. For this example we will therefore ignore [CLS] and just average.

Let’s extract average embeddings for each modality to obtain a single vector representation per input:

## Eliminate contribution from pad tokens (these values were obtained via batching),
## drop the [CLS] token at position 0, and average over the remaining real tokens
pad_mask = inputs["attention_mask"].unsqueeze(-1)  # 1 for real tokens, 0 for pads
captions_avg_embeds = (captions_embeddings * pad_mask[:4])[:, 1:, :].sum(axis=1).detach() / (num_values[:4].view(-1, 1) - 1)
dialogue_avg_embeds = (dialogue_embeddings * pad_mask[4:])[:, 1:, :].sum(axis=1).detach() / (num_values[4:].view(-1, 1) - 1)

## Remove the [CLS] token from ViT's output, since we just want the patch-embedding average
image_avg_embeds = torch.mean(image_hidden_states[:, 1:, :], axis=1).detach()


print(f"{captions_avg_embeds.shape = }")
print(f"{dialogue_avg_embeds.shape = }")
print(f"{image_avg_embeds.shape = }")

We now have a single 768-dimensional embedding for each input across all modalities. However, these embeddings are derived from models that were trained independently and are not aligned in any shared space. To illustrate this, let’s compare the similarity between embeddings from different modalities:

import torch
import seaborn as sns
import matplotlib.pyplot as plt

def plot_similarity(similarity, xlab, ylab, xticks, yticks, ax):
    # Visualization of the similarity matrix on the provided axis
    sns.heatmap(similarity.numpy(), annot=True, cmap='coolwarm', xticklabels=xticks, yticklabels=yticks, ax=ax)
    ax.set_title(f"Similarity between {xlab} and {ylab} Embeddings")
    ax.set_xlabel(xlab)
    ax.set_ylabel(ylab)

fig, axs = plt.subplots(2, 2, figsize=(18, 12))

## Expected Heatmap For Top Row: Diagonal Matrix. Plot [0][0] is a demo of this
sim_mtx = (image_avg_embeds @ image_avg_embeds.T).softmax(dim=0)
plot_similarity(sim_mtx, "Image", "Image", img_files, img_files, axs[0][0])

sim_mtx = (captions_avg_embeds @ image_avg_embeds.T).softmax(dim=0)
plot_similarity(sim_mtx, "Captions", "Image", text_captions, img_files, axs[0][1])

## Expected Heatmap For Bottom Row: Undefined, but [1][1] doesn't look terrible...
sim_mtx = (dialogue_avg_embeds @ captions_avg_embeds.T).softmax(dim=0)
plot_similarity(sim_mtx, "Dialogue", "Captions", text_dialogue, text_captions, axs[1][0])

sim_mtx = (dialogue_avg_embeds @ image_avg_embeds.T).softmax(dim=0)
plot_similarity(sim_mtx, "Dialogue", "Image", text_dialogue, img_files, axs[1][1])

plt.tight_layout()
plt.show()

If you look closely, you'll notice that some sensible relationships are still on display. Strong alignment between certain embeddings could indicate actual correspondence (possibly the bottom-right plot), but it could also be entirely coincidental (possibly the bottom-left plot). This is because these representations were optimized independently for their own downstream objectives, and therefore form completely different implicit spaces.

Joint Optimization with CLIP

To address this limitation, models like CLIP (Contrastive Language–Image Pre-training) are trained specifically to project different modalities — like images and text — into a shared embedding space in which they can be easily compared. This shared space is optimized jointly across both modalities so that the model learns representations that are mutually consistent, i.e. similar images and text are mapped closer together while unrelated pairs are pushed apart.

CLIP does this through a contrastive learning objective, which aligns the embeddings of paired images and their captions by minimizing the distance between them in the shared space while maximizing the distance between unrelated pairs. This joint optimization process encourages the model to learn representations that facilitate multimodal tasks like image-text retrieval (a simplified sketch of this objective appears below).

Image source: Learning Transferable Visual Models From Natural Language Supervision (2021)
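Conceptually, the objective is a symmetric cross-entropy over the batch's image-text similarity matrix, where matched pairs lie on the diagonal. Below is a simplified sketch of that loss (illustrative only, not CLIP's exact implementation, which also learns the temperature):

import torch
import torch.nn.functional as F

def clip_style_contrastive_loss(image_embeds, text_embeds, temperature=0.07):
    """Simplified CLIP-style loss: the i-th image is paired with the i-th caption."""
    image_embeds = F.normalize(image_embeds, dim=-1)
    text_embeds = F.normalize(text_embeds, dim=-1)
    logits = image_embeds @ text_embeds.T / temperature   # (batch, batch) similarity matrix
    targets = torch.arange(len(logits), device=logits.device)
    loss_img = F.cross_entropy(logits, targets)           # image -> text direction
    loss_txt = F.cross_entropy(logits.T, targets)         # text -> image direction
    return (loss_img + loss_txt) / 2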

To demonstrate the power of joint optimization, let’s explore how CLIP embeddings perform on a task it’s actually designed for — aligning images with their corresponding text descriptions. This task is one of CLIP’s core strengths, and we’ll see how well it handles the alignment of images and captions in its shared embedding space.

from transformers import CLIPProcessor, CLIPModel

text_captions = ("Cat with paint", "rock statue", "frogs on leaf", "jellyfish")
text_dialogue = ("Cats don't usually", 'like water', 'but this one likes paint', 'quite a lot')

## TODO: Load in a clip model of choice, reading over the model-card recommendations.
## HINT: We'd recommend openai/clip-vit-base-patch32
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

#######################################################################
## TODO: Compute the text and image embeddings for analysis
inputs_text = processor(text=text_captions + text_dialogue, return_tensors="pt", padding=True)
inputs_images = processor(images=images, return_tensors="pt")

#######################################################################
## TODO: Get the text and image embeddings for final visual
with torch.no_grad():
    text_embeddings = model.get_text_features(**inputs_text)
    image_embeddings = model.get_image_features(**inputs_images)

fig, axs = plt.subplots(1, 2, figsize=(12, 6))

sim_mtx = (text_embeddings[:4] @ image_embeddings.T).softmax(dim=0)
plot_similarity(sim_mtx, "Captions", "Image", text_captions, img_files, axs[0])

sim_mtx = (text_embeddings[4:] @ image_embeddings.T).softmax(dim=0)
plot_similarity(sim_mtx, "Dialogue", "Image", text_dialogue, img_files, axs[1])

plt.tight_layout()
plt.show()
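As an aside, CLIPModel can also compute this comparison for you: passing text and images through the processor together and calling the model's forward pass returns logits_per_image, which already applies the learned temperature (logit_scale). A quick equivalent check:

# Alternative: let CLIPModel build the (temperature-scaled) similarity matrix directly
clip_inputs = processor(text=list(text_captions), images=images, return_tensors="pt", padding=True)
with torch.no_grad():
    clip_outputs = model(**clip_inputs)
print(clip_outputs.logits_per_image.shape)            # (num_images, num_captions)
print(clip_outputs.logits_per_image.softmax(dim=-1))  # each image's distribution over captions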


Part 4: Text-Guided Image Diffusion

There are two broad ways of progressively generating non-trivial outputs: autoregressive generation, which emits one finalized token at a time, and iterative refinement, which improves a full draft of the output over several passes. We've established that autoregression is a great fit for sequence outputs such as text, but other modalities, such as images, admit multiple competing generation strategies depending on the desired outcome.

Pulling In A Diffusion Model

For output domains where the autoregressive formulation alone is insufficient to generate good-quality results (i.e., representations that aren't natural sequences), we may need specialized decoders better suited to the structure of the output modality. One especially notable approach is progressive denoising refinement, where information is moved toward the final output gradually, as opposed to accumulating the output one finalized building block at a time. This is how diffusion models work.
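To make this concrete, here is a deliberately simplified sketch of a denoising sampling loop. The denoiser and the update rule are hypothetical placeholders meant to convey the idea of iterative refinement, not the actual mathematics used by the library below:

import torch

def sample_by_iterative_denoising(denoiser, num_steps, shape):
    """Illustrative only: start from pure noise and refine it step by step."""
    x = torch.randn(shape)                    # begin with pure Gaussian noise
    for t in reversed(range(num_steps)):      # walk the timesteps back toward t = 0
        predicted_noise = denoiser(x, t)      # model estimates the noise still present in x
        x = x - predicted_noise / num_steps   # crude update: remove a fraction of it each step
    return x                                  # progressively refined output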

The following code block allows you to pull in one of the state-of-the-art image diffusion models, StabilityAI's Stable Diffusion XL (SDXL-1.0) model. It sits alongside other popular diffusion-based image generators like OpenAI's DALL·E models and operates with a simple prompt-based API as follows:

from diffusers import DiffusionPipeline
import torch

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
).to("cuda")

prompt = "An astronaut riding a green horse"

image = pipe(prompt=prompt).images[0]
image.show()
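The pipeline call also exposes knobs worth experimenting with, such as the number of denoising steps and the classifier-free guidance scale (both standard parameters of the diffusers pipeline call; the values below are just examples):

# Fewer steps trade quality for speed; guidance_scale controls how strongly the
# sample is pushed toward the text prompt (classifier-free guidance)
image = pipe(
    prompt=prompt,
    negative_prompt="blurry, low quality",
    num_inference_steps=30,
    guidance_scale=7.5,
).images[0]
image.show()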