Mixture‑of‑Experts model - chunhualiao/public-docs GitHub Wiki
Mixture‑of‑Experts (MoE) – A Detailed Walk‑through
(focus on modern sparsely‑gated MoE used in large language models, but the principles apply to any MoE architecture)
1. What is a Mixture‑of‑Experts?
A Mixture‑of‑Experts (MoE) is a neural‑network architecture that combines multiple specialised sub‑networks (“experts”) with a learned gating mechanism that decides which expert(s) should process each input (or part of an input).
| Component | Role | Typical Implementation |
|---|---|---|
| Experts | Parameter‑rich feed‑forward (FFN) sub‑layers, often identical in structure (e.g., a 2‑layer MLP) but with independent weights. | Expert_i(x) = W2_i·σ(W1_i·x) (σ = ReLU, GELU, etc.) |
| Gate (Router) | Computes a probability distribution over the expert set for an input token (or token‑group). | g(x) = softmax(U·x) where U is a small projection matrix |
| Load‑Balancing / Auxiliary Losses | Encourage the gate to use all experts fairly and avoid collapsed routing. | “Importance” loss, “load” loss, “capacity” loss, etc. |
| Sparsity | Only a few experts (often 1–2) are activated per token → huge parameter count with modest compute. | Top‑k selection on the gate scores (k = 1 or 2). |
The key idea: increase model capacity (parameter count) without linearly increasing FLOPs or latency by sparsely activating a tiny subset of the total parameters for each forward pass.
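The capacity/compute trade-off is easy to make concrete with back-of-the-envelope arithmetic. The numbers below are hypothetical, chosen only to illustrate the scaling effect:

```python
# Illustrative parameter counts for a sparsely gated MoE FFN layer.
# All sizes are hypothetical examples, not taken from any specific model.
d_model = 4096       # hidden dimension
d_ff = 16384         # FFN inner dimension
n_expert = 64        # total experts in the layer
top_k = 2            # experts activated per token

params_per_expert = 2 * d_model * d_ff            # two weight matrices, no bias
total_expert_params = n_expert * params_per_expert
active_expert_params = top_k * params_per_expert  # parameters touched per token

print(f"total expert params:  {total_expert_params / 1e9:.1f} B")
print(f"active per token:     {active_expert_params / 1e9:.2f} B")
print(f"capacity/compute ratio: {total_expert_params // active_expert_params}x")
```

With these numbers the layer stores ~8.6 B expert parameters but each token only exercises ~0.27 B of them, a 32× gap between stored capacity and per-token compute.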
2. Architecture Blueprint (Transformer‑style MoE)
Below is the canonical way MoE is embedded inside a Transformer block (e.g., the Switch Transformer or GLaM).
Input token embeddings → Multi‑Head Self‑Attention (MHSA) → Add & Norm
│
└─► MoE Feed‑Forward Layer (sparse) ─► Add & Norm
(gate + 1–2 experts per token)
- MHSA is dense as usual (O(N²) attention).
- MoE FFN replaces the usual dense FFN.
- The gate receives the output of the attention sub‑layer (or the embedding directly) and decides which expert(s) to invoke.
3. Detailed Inference Flow (Per Token)
Assume a batch of B sequences, each of length L. For simplicity, treat each token individually; frameworks typically vectorise across the batch.
3.1 Forward Pass Through the Gate
- Project the token representation: [ s = U\,x \qquad (U \in \mathbb{R}^{E \times d}) ] where E = number of experts and d = hidden dimension.
- Compute raw scores (logits) for each expert: [ \text{logits}_i = s_i \qquad i = 1 \ldots E ]
- Optional scaling / temperature: [ \tilde{s}_i = \frac{s_i}{\tau} ] A lower τ makes the distribution sharper.
- Select top‑k experts (k is usually 1 or 2). Use torch.topk or an equivalent to extract indices idx and scores score_topk.
- Sparsify: set all non‑top‑k scores to -inf and apply softmax only over the selected k, obtaining routing probabilities p_i:
    probs = softmax(score_topk)   # shape: (B*L, k)
    experts_idx = idx             # shape: (B*L, k)
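The gate steps above can be sketched framework-free with NumPy (a toy sketch; the `route` helper and its shapes are illustrative, and a real implementation would use `torch.topk` on the accelerator):

```python
import numpy as np

def route(x, U, k=2, tau=1.0):
    """Top-k gating: returns expert indices and renormalised probabilities.

    x: (T, d) token representations, U: (E, d) gate projection matrix.
    """
    logits = x @ U.T / tau                       # (T, E) raw gate scores
    idx = np.argsort(-logits, axis=1)[:, :k]     # top-k expert indices per token
    top = np.take_along_axis(logits, idx, axis=1)
    top = top - top.max(axis=1, keepdims=True)   # numerically stable softmax
    probs = np.exp(top) / np.exp(top).sum(axis=1, keepdims=True)
    return idx, probs                            # (T, k), (T, k); rows sum to 1

rng = np.random.default_rng(0)
x = rng.standard_normal((5, 8))                  # 5 tokens, d = 8
U = rng.standard_normal((4, 8))                  # E = 4 experts
idx, probs = route(x, U)
print(idx.shape, probs.sum(axis=1))              # (5, 2) and rows of ~1.0
```

Note that the softmax here runs only over the k selected scores, matching the sparsify step: non-selected experts receive exactly zero routing probability.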
3.2 Dispatch (Routing) Tokens to Experts
Because each expert processes a variable number of tokens, we need a dispatch step:
- Create a “dispatch mask” of shape (B*L, E) where mask[t, i] = p_i if expert i is selected for (flattened) token t, otherwise 0.
- Optionally clamp per‑expert capacity (e.g., max tokens per expert). Overflowed tokens can be dropped or rerouted to a “fallback” expert.
3.3 Expert Computation (Vectorised)
Each expert i receives a compact tensor of its assigned tokens:
expert_input_i = gather(input, where(experts_idx == i))
output_i = Expert_i(expert_input_i) # standard dense MLP
If the framework provides fused Mixture‑of‑Experts kernels (e.g., DeepSpeed‑MoE, Tutel, or Megatron‑LM), the gather/scatter is handled automatically on GPU/TPU.
3.4 Combine (Reduce) Results
After each expert returns its output, we scatter them back:
output[b,l] = Σ_i probs[b,l,i] * expert_output_i[corresponding token]
Because each token may have up to k contributions, we simply sum (or average) the weighted outputs.
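The dispatch-and-combine round trip can be shown end to end with a toy dense-mask version (hypothetical shapes; the round-robin routing and matrix "experts" are stand-ins, and production code uses gather/scatter kernels instead of a dense (T, E) mask):

```python
import numpy as np

T, E, D = 6, 3, 4                      # tokens, experts, model dim
rng = np.random.default_rng(1)
x = rng.standard_normal((T, D))

# Pretend the gate produced this dense dispatch mask: mask[t, e] = routing prob.
mask = np.zeros((T, E))
for t in range(T):
    mask[t, t % E] = 1.0               # round-robin routing, k = 1, prob = 1.0

# Each "expert" here is just a fixed projection matrix, for illustration only.
experts = [rng.standard_normal((D, D)) for _ in range(E)]

out = np.zeros_like(x)
for e in range(E):
    routed = mask[:, e] > 0            # which tokens go to expert e
    y = x[routed] @ experts[e]         # expert forward pass on its token batch
    out[routed] += mask[routed, e][:, None] * y   # weighted combine (scatter-add)

print(out.shape)                       # (6, 4): one combined output per token
```

Because k = 1 here with probability 1.0, each output row is exactly one expert's output; with k = 2 the scatter-add would accumulate two weighted contributions per token.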
3.5 Continue Downstream
The MoE output proceeds through the Add‑&‑Norm residual connection and into the next Transformer layer.
4. Training the MoE
Training introduces several nuances beyond the inference flow:
4.1 Objective
Standard language‑model loss (e.g., cross‑entropy) is unchanged; MoE components are trained end‑to‑end via back‑propagation, just like any other layer.
4.2 Gradient Flow Through the Gate
- Straight‑through estimator: The gate’s top‑k selection is discrete. During the backward pass, the gradient w.r.t. the gating logits is typically approximated by treating the soft selection as if it were continuous (i.e., propagate through the softmax over the selected k).
- Gumbel‑Softmax / Noisy‑Top‑k: Some early MoE papers added Gumbel noise to the logits before top‑k, turning the gating into a stochastic but differentiable operation. Modern large‑scale MoE (Switch, GLaM) often uses deterministic top‑k with a small auxiliary loss instead of noise.
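Noisy top-k gating can be sketched by perturbing the logits with Gumbel noise before selection (a sketch of the general idea, not any specific paper's exact formulation):

```python
import numpy as np

def noisy_topk(logits, k=1, rng=None):
    """Add Gumbel(0, 1) noise to gate logits, then take top-k.

    Taking argmax over (logits + Gumbel noise) samples an expert from
    softmax(logits), so routing is stochastic during training while
    following the gate's learned distribution on average.
    """
    rng = rng or np.random.default_rng()
    gumbel = -np.log(-np.log(rng.uniform(size=logits.shape)))
    noisy = logits + gumbel
    return np.argsort(-noisy, axis=1)[:, :k]    # (T, k) expert indices

logits = np.array([[2.0, 0.5, 0.1],             # 2 tokens, 3 experts
                   [0.1, 0.1, 3.0]])
picks = noisy_topk(logits, k=1, rng=np.random.default_rng(0))
print(picks)   # stochastic: a different seed can flip the chosen experts
```

The noise keeps low-scoring experts occasionally selected, which gives them gradient signal and counteracts routing collapse.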
4.3 Load‑Balancing and Auxiliary Losses
A naïve gate can collapse—all tokens go to a few experts, leaving most weights idle. To prevent this, MoE models augment the loss with regularisers that encourage balanced expert usage.
4.3.1 Basic "Importance" Loss
Define:
- C_i = total routing probability mass assigned to expert i across the batch.
- E_i = total number of tokens actually dispatched to expert i (capacity usage).
The importance loss penalises deviation from uniformity:
[ \mathcal{L}_{\text{importance}} = \sum_{i=1}^{E} \left( \frac{C_i}{\sum_j C_j} - \frac{1}{E} \right)^2 ]
4.3.2 "Load" Loss
Similarly, penalise uneven token count per expert:
[ \mathcal{L}_{\text{load}} = \sum_{i=1}^{E} \left( \frac{E_i}{\sum_j E_j} - \frac{1}{E} \right)^2 ]
4.3.3 Combined Regulariser
Typical practice (e.g., Switch Transformer) multiplies each component by a coefficient λ ≈ 0.01 and adds them to the main loss:
[ \mathcal{L} = \mathcal{L}_{\text{CE}} + \lambda_{\text{imp}} \, \mathcal{L}_{\text{importance}} + \lambda_{\text{load}} \, \mathcal{L}_{\text{load}} ]
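Both regularisers are direct functions of the dispatch statistics, so they can be transcribed in a few lines (a direct transcription of the formulas above; the balanced/collapsed masks are made-up examples):

```python
import numpy as np

def aux_losses(dispatch_mask):
    """dispatch_mask: (T, E), entry = routing prob if the expert was chosen, else 0."""
    E = dispatch_mask.shape[1]
    C = dispatch_mask.sum(axis=0)              # probability mass per expert (C_i)
    tokens = (dispatch_mask > 0).sum(axis=0)   # token count per expert (E_i)
    imp = np.sum((C / C.sum() - 1.0 / E) ** 2)
    load = np.sum((tokens / tokens.sum() - 1.0 / E) ** 2)
    return imp, load

# Perfectly balanced routing: 8 tokens, 4 experts, k = 1, 2 tokens each.
balanced = np.tile(np.eye(4), (2, 1))
imp, load = aux_losses(balanced)
print(imp, load)                               # 0.0 0.0

# Collapsed routing: every token goes to expert 0 -> large penalty.
collapsed = np.zeros((8, 4)); collapsed[:, 0] = 1.0
print(aux_losses(collapsed))
```

Balanced usage zeroes out both terms; a fully collapsed gate pays the maximum penalty, which is exactly the pressure that keeps all experts alive.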
4.4 Capacity Factor & Token Dropping
Each expert has a capacity C = capacity_factor * (batch_tokens / num_experts).
If more tokens are routed to an expert than its capacity, the excess tokens are dropped (or sent to a fallback expert). Dropped tokens are typically zero‑gradient for that expert, which can be interpreted as a regularisation effect.
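A worked example of the capacity rule (illustrative numbers only):

```python
import math

batch_tokens = 1024          # B * L tokens in the batch
num_experts = 8
capacity_factor = 1.25

# capacity = capacity_factor * (batch_tokens / num_experts), rounded up
capacity = math.ceil(capacity_factor * batch_tokens / num_experts)
print(capacity)              # 160 tokens per expert

# If the router sends 200 tokens to one expert, the excess is dropped.
routed = 200
dropped = max(0, routed - capacity)
print(dropped)               # 40 tokens dropped (or sent to a fallback expert)
```

A capacity factor of 1.0 would allow exactly the average load per expert; the 25% slack absorbs moderate routing imbalance before any tokens are dropped.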
4.5 Expert Parallelism (Data/Model Parallel)
Because each expert is independent, MoE training is model‑parallel friendly: experts can be sharded across multiple devices. The common strategy:
- Expert parallelism: partition the expert dimension across devices (e.g., each GPU hosts n_expert / num_devices experts).
- Data parallelism per device: each device holds a replica of the remaining model parameters (attention layers, gating weights, token embeddings).
- All‑to‑all communication is required after the gate to exchange token batches so that each device receives the tokens assigned to its experts. Modern frameworks (e.g., Mesh TensorFlow, DeepSpeed‑MoE, Megatron‑LM) implement this efficiently with NCCL‑based collective ops.
4.6 Training Schedule & Warm‑up
- Gate Warm‑up: Early in training, the gating network may be frozen or trained with a higher temperature to avoid premature specialization.
- Curriculum of Capacity: Some works slowly increase the capacity factor as training progresses, permitting the model to first learn coarse routing before fine‑grained load balancing.
4.7 Back‑propagation Through Experts
Each expert’s weight update follows the standard Adam (or LAMB, RMSProp) rule. Since most tokens bypass a given expert, gradient sparsity is natural: an expert only receives gradients from the tokens it processed.
5. Variants of MoE
| Variant | What Changes | Typical Use‑case |
|---|---|---|
| Switch Transformer | k = 1 (single expert per token) + optional “auxiliary” feed‑forward fallback expert. | Extreme scaling (up to 1 trillion parameters) with low inference cost. |
| GLaM (Generalist Language Model) | k = 2 (two experts) + fine‑grained load‑balancing loss. | Improves quality vs. Switch while still saving compute. |
| Sparse Mixture of Experts (SMoE) | Experts are different architectures (e.g., convolution, RNN, transformer). | Multi‑modal or heterogeneous input. |
| Hierarchical MoE | Multiple gating levels (e.g., “router‑router‑expert”). | Scalability to millions of experts. |
| Conditional Computation MoE | Gate conditioned not just on token features but also on task embeddings. | Multi‑task or continual‑learning setups. |
| Mixture of Depth‑wise Experts (MoDE) | Experts are very shallow (e.g., 1‑layer MLP). | Faster inference, used in some vision‑transformer variants. |
6. Inference Optimizations
- Static Routing (Compilation): for latency‑critical serving, the gating decisions for a given input can be pre‑computed and baked into a static graph, eliminating the top‑k operation at runtime.
- Expert Pruning / Caching: frequently used experts can be cached on the inference device, while rarely used experts are off‑loaded to slower memory (e.g., CPU or SSD).
- Batch‑Level Token Packing: tokens from many sequences are packed so that each expert sees a dense batch, minimizing padding and improving GPU utilization.
- Low‑Precision (FP8 / INT8) Quantisation: since each expert is a standard MLP, it can be quantised independently without breaking the routing logic.
- Kernel Fusion: specialized kernels (e.g., MoE dispatch and combine kernels) fuse gather‑multiply‑scatter to reduce memory traffic, which is crucial on GPUs/TPUs.
7. Benefits & Trade‑offs
| Pros | Cons |
|---|---|
| Parameter efficiency: many times the parameters of a dense model of the same width at roughly the same FLOPs per token. | Complex engineering: needs all‑to‑all communication, careful load‑balancing, and custom kernels. |
| Specialisation: experts can specialise on language domains, topics, or token types, leading to higher quality. | Inference variability: latency can fluctuate if expert loads are imbalanced. |
| Scalable training: expert parallelism lets you scale beyond a single device’s memory. | Routing overhead: gate computation + dispatch/scatter adds small extra latency compared to a pure dense layer. |
| Graceful degradation: dropping a subset of experts at inference reduces compute with modest performance loss. | Sparse gradient updates: some experts may receive few updates, potentially requiring longer training or regularisation. |
8. Pseudocode (PyTorch‑like) – End‑to‑End Forward Pass
import torch
import torch.nn as nn
import torch.nn.functional as F

class MoELayer(nn.Module):
    def __init__(self, d_model, d_ff, n_expert, top_k=2, capacity_factor=1.25):
        super().__init__()
        self.n_expert = n_expert
        self.top_k = top_k
        self.capacity_factor = capacity_factor
        # --- gating network ---
        self.w_gate = nn.Linear(d_model, n_expert, bias=False)
        # --- experts (identical MLPs) ---
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff, bias=False),
                nn.GELU(),
                nn.Linear(d_ff, d_model, bias=False)
            ) for _ in range(n_expert)
        ])

    def forward(self, x):
        """
        x: (B, L, D) token representations
        """
        B, L, D = x.shape
        flat_x = x.view(-1, D)                                       # (B*L, D)

        # 1) Gate scores -------------------------------------------------
        logits = self.w_gate(flat_x)                                 # (B*L, E)
        topk_vals, topk_idx = torch.topk(logits, self.top_k, dim=1)  # (B*L, k)
        # softmax over the selected logits only
        probs = F.softmax(topk_vals, dim=1)                          # (B*L, k)

        # 2) Build dispatch mask (B*L, E) -------------------------------
        # mask[t, e] = prob if expert e selected for token t, else 0
        dispatch_mask = torch.zeros_like(logits)                     # (B*L, E)
        dispatch_mask.scatter_(1, topk_idx, probs)

        # 3) Per-expert capacity ----------------------------------------
        # capacity = factor * avg tokens per expert (rounded up);
        # recomputed every call since B*L can vary between batches
        capacity = int(self.capacity_factor * (B * L) / self.n_expert) + 1
        # dispatch_mask.sum(0) gives per-expert load for the auxiliary loss

        # 4) Gather tokens for each expert ------------------------------
        expert_inputs = []
        for i in range(self.n_expert):
            # indices of tokens routed to expert i
            idx = torch.nonzero(dispatch_mask[:, i] > 0, as_tuple=False).squeeze(1)
            if idx.numel() > capacity:
                idx = idx[:capacity]                                 # drop overflow
            # gather token representations
            expert_inputs.append(flat_x[idx])

        # 5) Expert forward passes --------------------------------------
        expert_outputs = []
        for expert, inp in zip(self.experts, expert_inputs):
            if inp.numel() == 0:
                expert_outputs.append(inp)                           # empty (0, D)
            else:
                expert_outputs.append(expert(inp))

        # 6) Scatter back (combine) -------------------------------------
        out = torch.zeros_like(flat_x)                               # (B*L, D)
        for i, out_i in enumerate(expert_outputs):
            if out_i.numel() == 0:
                continue
            idx = torch.nonzero(dispatch_mask[:, i] > 0, as_tuple=False).squeeze(1)
            # match the tokens kept after capacity truncation
            idx = idx[:out_i.shape[0]]
            prob = dispatch_mask[idx, i].unsqueeze(1)                # (N_i, 1)
            out[idx] = out[idx] + prob * out_i
        return out.view(B, L, D)
Notes
- The above is illustrative; production‑grade MoE uses all‑to‑all collective communication to move tokens across GPUs rather than the Python loops.
- The capacity check implements the “drop‑overflow” rule.
- In practice you also compute the auxiliary load‑balancing loss inside the forward method.
9. Summary Checklist
- Model definition
  - Choose E (number of experts) and k (top‑k).
  - Implement a lightweight gate (Linear → softmax).
  - Create identical expert sub‑networks (usually a 2‑layer FFN).
- Training
  - Use standard LM loss + load‑balancing auxiliary losses.
  - Apply capacity factor and possibly expert dropout.
  - Parallelize by sharding experts across devices; use all‑to‑all for token routing.
  - Optionally warm up the gate (higher temperature, frozen gate early).
- Inference
  - Same forward pass: gate → top‑k selection → dispatch → expert compute → combine.
  - Optimize with kernel fusion, static routing, expert caching, and quantisation.
- Pros / Cons
  - Huge parameter capacity with modest compute → state‑of‑the‑art LM scaling.
  - Engineering complexity and routing overhead → need careful system design.
TL;DR
A Mixture‑of‑Experts model splits a huge feed‑forward block into many independent experts and learns a gating network that selects a tiny subset (usually 1‑2) of those experts per token. The gate computes scores, picks the top‑k experts, and sends each token only to the chosen experts. Training is end‑to‑end but requires auxiliary load‑balancing losses to keep all experts alive and a capacity constraint to bound per‑expert workload. Parallelism is achieved by sharding experts across devices and performing an all‑to‑all communication step to move tokens to the right device. At inference time the same sparse routing happens, but one can further optimise dispatch/combine kernels, cache hot experts, or even pre‑compute routes for latency‑critical serving. This sparse‑activation strategy lets MoE models reach trillion‑parameter scales while using a fraction of the FLOPs of an equivalently sized dense model, delivering higher quality language generation with manageable computational cost.