20. Multi-Modal Training

Chapter 20 of 24 · 15 min

Training a multimodal model means balancing the learning signal from each modality. When gradients from one modality conflict with or dominate those of another, training can become unstable, and performance on the under-represented modality can degrade.

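Late fusion architectures train each modality independently and combine predictions at the decision layer. This simplifies training but limits the model's ability to learn cross-modal representations. Early fusion concatenates raw inputs or low-level features and trains jointly, which enables cross-modal learning but makes optimization harder. The model below follows the early-fusion pattern: it projects video and audio features into a shared hidden space, concatenates them along the sequence dimension, and encodes them jointly with a transformer.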

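import math

import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    # Assumed helper: the chapter does not define it, so this is a standard
    # sinusoidal encoding for (batch, seq, dim) inputs with an even hidden_dim.
    def __init__(self, hidden_dim, max_len=5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, hidden_dim, 2) * (-math.log(10000.0) / hidden_dim)
        )
        pe = torch.zeros(max_len, hidden_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        # Add the encoding for the first x.size(1) positions
        return x + self.pe[: x.size(1)]


class MultiModalTransformer(nn.Module):
    # num_actions (number of output classes) is passed in explicitly
    def __init__(self, video_dim, audio_dim, hidden_dim, num_heads, num_layers, num_actions):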
        super().__init__()
        self.video_proj = nn.Linear(video_dim, hidden_dim)
        self.audio_proj = nn.Linear(audio_dim, hidden_dim)
        self.pos_encoder = PositionalEncoding(hidden_dim)
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(
            encoder_layer, 
            num_layers=num_layers
        )
        
        self.classifier = nn.Linear(hidden_dim, num_actions)
    
    def forward(self, video_features, audio_features):
        # Project to common space
        v = self.video_proj(video_features)
        a = self.audio_proj(audio_features)
        
        # Concatenate along sequence dimension
        combined = torch.cat([v, a], dim=1)
        combined = self.pos_encoder(combined)
        
        # Joint encoding
        encoded = self.transformer(combined)
        
        # Pool and classify
        return self.classifier(encoded.mean(dim=1))
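
For contrast with the joint model above, a late-fusion counterpart can be sketched as follows. This is a minimal illustration under stated assumptions, not a reference implementation: the class name `LateFusionClassifier`, the two-layer branches, the mean pooling over time, and the feature sizes are all hypothetical choices made for the example.

```python
import torch
import torch.nn as nn


class LateFusionClassifier(nn.Module):
    """Each modality has its own encoder and classifier; only logits are combined."""

    def __init__(self, video_dim, audio_dim, hidden_dim, num_actions):
        super().__init__()
        self.video_branch = nn.Sequential(
            nn.Linear(video_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, num_actions)
        )
        self.audio_branch = nn.Sequential(
            nn.Linear(audio_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, num_actions)
        )

    def forward(self, video_features, audio_features):
        # Mean-pool each sequence over time, then classify per modality.
        video_logits = self.video_branch(video_features.mean(dim=1))
        audio_logits = self.audio_branch(audio_features.mean(dim=1))
        # Decision-level fusion: average the two sets of logits.
        return 0.5 * (video_logits + audio_logits)


# Hypothetical sizes, chosen only to exercise the forward pass.
model = LateFusionClassifier(video_dim=512, audio_dim=128, hidden_dim=64, num_actions=5)
video = torch.randn(4, 16, 512)   # (batch, video_frames, video_dim)
audio = torch.randn(4, 40, 128)   # (batch, audio_frames, audio_dim)
logits = model(video, audio)      # shape: (4, 5)
```

Because the branches never exchange hidden states, each one can be trained on its own loss and the averaging step applied only at inference. That independence is what makes late fusion easy to train, and also what prevents it from learning cross-modal features.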

Gradient balancing aims to prevent one modality from dominating training. When one modality produces consistently larger gradients, its updates can overwhelm those of the other modality's representations. Common techniques include gradient normalization, modality-specific learning rates, and gradient orthogonality constraints.
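
Two of these techniques, gradient normalization and modality-specific learning rates, can be sketched in a few lines of PyTorch. The helper names `grad_norm` and `balance_modality_grads`, the rule of rescaling each branch to the mean norm, the learning rates, and the toy feature sizes are assumptions made for illustration; published methods differ in how they pick the target norm.

```python
import torch
import torch.nn as nn


def grad_norm(module: nn.Module) -> torch.Tensor:
    """L2 norm over all gradients currently stored on a module's parameters."""
    norms = [p.grad.norm() for p in module.parameters() if p.grad is not None]
    return torch.stack(norms).norm() if norms else torch.tensor(0.0)


def balance_modality_grads(branches: dict, eps: float = 1e-12) -> dict:
    """Rescale each branch's gradients to the mean gradient norm across branches.

    Returns the norms measured before rescaling, which are useful for logging.
    """
    norms = {name: grad_norm(m) for name, m in branches.items()}
    target = torch.stack(list(norms.values())).mean()
    for name, module in branches.items():
        scale = target / (norms[name] + eps)
        for p in module.parameters():
            if p.grad is not None:
                p.grad.mul_(scale)
    return {name: n.item() for name, n in norms.items()}


torch.manual_seed(0)
video_proj = nn.Linear(512, 64)   # hypothetical feature sizes
audio_proj = nn.Linear(128, 64)
head = nn.Linear(64, 5)

# Modality-specific learning rates via optimizer parameter groups.
optimizer = torch.optim.AdamW([
    {"params": video_proj.parameters(), "lr": 1e-4},
    {"params": audio_proj.parameters(), "lr": 3e-4},
    {"params": head.parameters(), "lr": 1e-4},
])

video = torch.randn(8, 512) * 10.0   # deliberately larger scale than audio
audio = torch.randn(8, 128)
labels = torch.randint(0, 5, (8,))

logits = head(video_proj(video) + audio_proj(audio))
loss = nn.functional.cross_entropy(logits, labels)
optimizer.zero_grad()
loss.backward()

# Equalize the two branches' gradient norms before the update.
before = balance_modality_grads({"video": video_proj, "audio": audio_proj})
after = {"video": grad_norm(video_proj).item(), "audio": grad_norm(audio_proj).item()}
optimizer.step()
```

Only the two projection branches are rescaled here; the shared `head` keeps its original gradient. Logging the `before` norms over training is a cheap way to check whether one modality is in fact dominating before adding any balancing.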

Contrastive learning provides an alternative training objective that learns aligned representations across modalities. Video and audio embeddings trained with a contrastive loss are pulled together for matching pairs and pushed apart for mismatched ones, so semantically similar content tends to cluster in the shared space. This alignment can enable zero-shot transfer between modalities.
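
One common form of this objective is a symmetric InfoNCE loss over a batch of paired clips, in the style popularized by CLIP for images and text. The sketch below is an assumption about how it might be applied to video and audio: the function name `symmetric_contrastive_loss` and the temperature of 0.07 are illustrative, and row `i` of each embedding matrix is assumed to come from the same clip.

```python
import torch
import torch.nn.functional as F


def symmetric_contrastive_loss(video_emb, audio_emb, temperature=0.07):
    """Symmetric InfoNCE: row i of video_emb is the positive for row i of audio_emb."""
    v = F.normalize(video_emb, dim=-1)
    a = F.normalize(audio_emb, dim=-1)
    # Cosine similarity between every video and every audio embedding.
    logits = v @ a.t() / temperature
    targets = torch.arange(v.size(0), device=v.device)
    # Average the video-to-audio and audio-to-video classification losses.
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))


torch.manual_seed(0)
video_emb = torch.randn(16, 32)
# Nearly identical embeddings stand in for a well-aligned pair of encoders.
aligned_audio = video_emb + 0.01 * torch.randn(16, 32)
# Shifting rows by one breaks every pairing.
mismatched_audio = torch.roll(aligned_audio, shifts=1, dims=0)

aligned_loss = symmetric_contrastive_loss(video_emb, aligned_audio)
mismatched_loss = symmetric_contrastive_loss(video_emb, mismatched_audio)
```

The aligned pairing yields a much lower loss than the shifted one, which is the signal that drives the two encoders toward a shared embedding space. In practice the embeddings would come from trainable video and audio encoders, not random tensors.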

Local verification checkpoint

Run the smallest example from this chapter in a local workspace and record the package version, runtime, data path, and observed output. If the result depends on model size, vector count, CPU/GPU backend, or available memory, note that constraint beside the exercise so the lesson remains reproducible.
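
A small helper along these lines could capture that record. It is a sketch only: the function name `environment_record`, the chosen fields, and the path `data/sample` are hypothetical, and the timed region is a placeholder for whichever chapter example is being run.

```python
import json
import platform
import sys
import time

import torch


def environment_record(data_path: str) -> dict:
    """Collect the details needed to reproduce a local run."""
    return {
        "python": sys.version.split()[0],
        "platform": platform.platform(),
        "torch": torch.__version__,
        "cuda_available": torch.cuda.is_available(),
        "data_path": data_path,
    }


start = time.perf_counter()
# ... run the chapter example here ...
record = environment_record("data/sample")  # hypothetical path
record["runtime_seconds"] = round(time.perf_counter() - start, 4)
print(json.dumps(record, indent=2))
```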

EXERCISE

Implement a late fusion vs. early fusion comparison experiment. Train both architectures on a small multimodal dataset and compare convergence behavior and final performance.