Découvrez comment Mixture of Depths (MoD) optimise l'efficacité de l'IA en acheminant dynamiquement les jetons. Découvrez comment cette technique réduit les FLOP dans Ultralytics et les LLM.
Dans les architectures d'apprentissage profond, l'efficacité computationnelle est primordiale, en particulier lors du traitement de longues séquences ou d' entrées à haute résolution. Une nouvelle approche alloue dynamiquement les ressources de calcul en permettant au réseau de décider quelles parties de l'entrée nécessitent un traitement complet et lesquelles peuvent contourner certaines couches en toute sécurité. Cette stratégie de routage dynamique réduit la complexité computationnelle globale sans sacrifier la puissance prédictive ou la précision du modèle.
Le mélange de profondeurs (MoD) est une technique architecturale principalement appliquée aux architectures Transformer, dans lesquelles le modèle apprend à ignorer de manière dynamique le calcul de tokens spécifiques à différents niveaux. Les transformateurs traditionnels traitent chaque token à chaque niveau, qu'il s'agisse d'une information cruciale ou d'un contenu de remplissage. En revanche, un modèle MoD utilise un mécanisme de routage pour évaluer les tokens et leur attribuer une note. Seuls les tokens ayant obtenu les meilleurs scores, dans la limite d'une capacité prédéfinie, sont soumis à des blocs de calcul lourds, tels que des mécanismes d'attention ou des couches denses à propagation directe . Les tokens restants contournent le bloc via des connexions résiduelles, créant ainsi un « mélange de profondeurs » où différents tokens sont soumis à des niveaux de profondeur de traitement variables.
Cette méthode, popularisée par les récentes recherches de DeepMind et largement documentée dans le référentiel arXiv, réduit considérablement le nombre total d'opérations en virgule flottante (FLOP) nécessaires pendant l' apprentissage et l'inférence.
Il est facile de confondre ce concept avec celui de Mixture of Experts (MoE). Bien que les deux utilisent des mécanismes de routage, ils résolvent des problèmes différents :
La capacité à budgétiser dynamiquement les calculs rend cette technique très précieuse dans de nombreux domaines de la vision par ordinateur et du traitement du langage naturel.
Vous trouverez ci-dessous un PyTorch conceptuel illustrant comment un mécanisme de routage de base peut ignorer le calcul d'une partie des jetons d'entrée, simulant ainsi un comportement de routage en profondeur.
import torch
import torch.nn as nn
class MixtureOfDepthsBlock(nn.Module):
def __init__(self, d_model, capacity_factor=0.5):
super().__init__()
self.capacity_factor = capacity_factor
self.router = nn.Linear(d_model, 1)
self.heavy_compute = nn.Sequential(nn.Linear(d_model, d_model * 4), nn.GELU(), nn.Linear(d_model * 4, d_model))
def forward(self, x):
# x shape: (batch_size, seq_len, d_model)
seq_len = x.size(1)
capacity = int(seq_len * self.capacity_factor)
# 1. Compute routing scores
scores = self.router(x).squeeze(-1) # Shape: (batch_size, seq_len)
# 2. Identify top-k tokens to process
topk_indices = torch.topk(scores, capacity, dim=1).indices
# 3. Create an output tensor mirroring the input (residual baseline)
output = x.clone()
# 4. Apply heavy computation only to the selected tokens
for b in range(x.size(0)):
selected_tokens = x[b, topk_indices[b]]
processed_tokens = self.heavy_compute(selected_tokens)
output[b, topk_indices[b]] += processed_tokens
return output
# Example usage
dummy_input = torch.randn(2, 64, 128) # Batch=2, Seq=64, Dim=128
mod_block = MixtureOfDepthsBlock(d_model=128, capacity_factor=0.5)
output = mod_block(dummy_input)
print(f"Output shape: {output.shape}") # Expect (2, 64, 128)
En tirant parti de frameworks tels que PyTorch ou TensorFlow, les développeurs peuvent intégrer ces blocs d'optimisation de modèles personnalisés. De plus, des outils tels que la Ultralytics aident les équipes à gérer les données d'entraînement nécessaires pour former avec précision ces routeurs, tout en s'intégrant de manière transparente aux écosystèmes d'entreprise tels que Google AI.