Training MoE Right: Making Every Expert Count
Mixture-of-Experts models activate only a fraction of their parameters per token. The hard part is making sure each expert actually learns something useful.
The Expert Collapse Problem
The core idea of Mixture-of-Experts (MoE) is simple: instead of one giant FFN layer, use $N$ smaller FFN "experts" and a learned router that sends each token to the best $k$ of them. This gives you a model with $N \times$ the parameters but only $k \times$ the per-token compute.
The problem is that routing is learned jointly with the experts, and the learning dynamics conspire against you. A few failure modes dominate:
- Expert collapse. The router learns to send most tokens to 1-2 experts. Those experts get all the gradient signal, improve fastest, and attract even more tokens. The remaining experts starve and never learn. This is a rich-get-richer positive feedback loop.
- Uniform routing. The opposite extreme: the router hedges by sending tokens roughly equally to all experts. No expert ever specializes. The model degenerates into an expensive dense model.
- Representation collapse. Multiple experts converge to compute the same function. You pay for $N$ experts but get the effective capacity of far fewer.
Every technique in this post addresses one or more of these failure modes. The goal is always the same: encourage the router to spread tokens across experts so that each expert specializes on a meaningful subset of the input distribution.
How Routing Works
The router is typically a single linear layer that takes a token's hidden state $\mathbf{h} \in \mathbb{R}^d$ and produces logits over $N$ experts:

$$\mathbf{g} = \mathbf{W}_r \mathbf{h}, \qquad \mathbf{W}_r \in \mathbb{R}^{N \times d}$$

where $\mathbf{g} \in \mathbb{R}^N$ is a score for each expert. The router then selects the top-$k$ experts and computes gating weights using a softmax over the selected logits:

$$G_i = \begin{cases} \dfrac{e^{g_i}}{\sum_{j \in \mathcal{T}} e^{g_j}} & i \in \mathcal{T} \\[4pt] 0 & \text{otherwise} \end{cases} \qquad \mathcal{T} = \operatorname{top-}k(\mathbf{g})$$
$G_i$ is the gating weight for expert $i$. The MoE layer output is the weighted sum of the selected experts' outputs:

$$\mathbf{y} = \sum_{i=1}^{N} G_i \, E_i(\mathbf{h})$$

where $E_i(\mathbf{h})$ is expert $i$'s FFN output. Since most $G_i$ are zero, only $k$ FFN forward passes actually run.
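The routing mechanics above can be sketched in a few lines of NumPy. This is a minimal single-token illustration, not any library's actual implementation; the names (`moe_forward`, `W_r`) and shapes are assumptions for the sketch.

```python
import numpy as np

def moe_forward(h, W_r, experts, k=2):
    """Top-k token-choice MoE forward pass for a single token (sketch).

    h: (d,) hidden state; W_r: (N, d) router weights;
    experts: list of N callables, each mapping (d,) -> (d,).
    """
    g = W_r @ h                              # router logits, shape (N,)
    topk = np.argsort(g)[-k:]                # indices of the k largest logits
    z = np.exp(g[topk] - g[topk].max())      # stable softmax over selected logits only
    G = z / z.sum()                          # gating weights, sum to 1
    # weighted sum of only the k selected experts' outputs
    return sum(G[j] * experts[i](h) for j, i in enumerate(topk))

# toy usage: d=3, N=4 experts that just scale their input
rng = np.random.default_rng(0)
experts = [lambda x, s=s: s * x for s in (1.0, 2.0, 3.0, 4.0)]
y = moe_forward(rng.normal(size=3), rng.normal(size=(4, 3)), experts, k=2)
```

Note that only the $k$ selected experts are ever called, which is where the compute savings come from.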
Load Balancing Losses
The oldest and most common fix: add an auxiliary loss that penalizes uneven expert utilization. The idea appeared in GShard and was refined by the Switch Transformer.
Switch Transformer's Load Balancing Loss
For a batch of $T$ tokens and $N$ experts, define two quantities:

$$f_i = \frac{1}{T} \sum_{t=1}^{T} \mathbf{1}\!\left[\arg\max_j \, p_j(x_t) = i\right]$$

$f_i$ is the fraction of tokens for which expert $i$ was the top choice. This is a hard count.

$$p_i = \frac{1}{T} \sum_{t=1}^{T} p_i(x_t)$$

$p_i$ is the average softmax probability assigned to expert $i$ across all tokens (where $p_i(x_t)$ is the router's softmax probability for expert $i$ on token $x_t$). This is a soft, differentiable quantity.

The auxiliary loss is their dot product, summed over experts:

$$\mathcal{L}_{\text{aux}} = \alpha \, N \sum_{i=1}^{N} f_i \cdot p_i$$

$\alpha$ is a small coefficient, typically 0.01. $N$ is a scaling factor that makes the loss value independent of the number of experts.
Why use both $f_i$ and $p_i$ instead of just one? $f_i$ involves an argmax, so it is not differentiable. $p_i$ is differentiable but does not capture the actual routing decision (a token with $p_i = 0.49$ for expert A and $0.51$ for expert B still routes to B). Their product combines the hard routing reality ($f_i$) with a differentiable signal ($p_i$) that gradients can flow through.
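The loss is simple to compute in practice. A minimal NumPy sketch (the function name and interface are assumptions, not Switch Transformer's actual code):

```python
import numpy as np

def load_balancing_loss(router_probs, alpha=0.01):
    """Switch-style auxiliary loss: alpha * N * sum_i f_i * p_i (sketch).

    router_probs: (T, N) softmax probabilities per token.
    f_i comes from the hard argmax; p_i is the soft column average.
    """
    T, N = router_probs.shape
    top1 = router_probs.argmax(axis=1)              # hard routing decision per token
    f = np.bincount(top1, minlength=N) / T          # fraction routed to each expert
    p = router_probs.mean(axis=0)                   # average soft probability
    return alpha * N * np.dot(f, p)

# perfectly balanced routing: each of 4 tokens goes to a different expert
balanced = load_balancing_loss(np.eye(4))           # equals alpha exactly
```

Under perfect balance ($f_i = p_i = 1/N$) the unscaled loss is exactly 1, so the returned value equals $\alpha$; imbalance pushes it above that.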
Why Not KL Divergence from Uniform?
A natural alternative: measure how far the routing distribution is from uniform using KL divergence. If the target is a uniform distribution $u_i = 1/N$, the KL divergence from uniform is:

$$D_{KL}(p \,\|\, u) = \sum_{i=1}^{N} p_i \log \frac{p_i}{1/N} = \log N - H(p)$$

where $H(p) = -\sum_i p_i \log p_i$ is the entropy of $p$. Minimizing $D_{KL}(p \| u)$ is equivalent to maximizing the entropy of the router's soft distribution. This is well-defined and differentiable.
Worked example: KL divergence vs. load balancing loss
Consider $N = 4$ experts and $T = 8$ tokens. Suppose the router produces these soft probabilities and hard routing decisions:
| Token | $p_{E_1}$ | $p_{E_2}$ | $p_{E_3}$ | $p_{E_4}$ | Routed to |
|---|---|---|---|---|---|
| $t_1$ | 0.50 | 0.30 | 0.15 | 0.05 | $E_1$ |
| $t_2$ | 0.45 | 0.35 | 0.10 | 0.10 | $E_1$ |
| $t_3$ | 0.40 | 0.25 | 0.20 | 0.15 | $E_1$ |
| $t_4$ | 0.10 | 0.55 | 0.20 | 0.15 | $E_2$ |
| $t_5$ | 0.15 | 0.50 | 0.25 | 0.10 | $E_2$ |
| $t_6$ | 0.20 | 0.20 | 0.45 | 0.15 | $E_3$ |
| $t_7$ | 0.48 | 0.22 | 0.18 | 0.12 | $E_1$ |
| $t_8$ | 0.42 | 0.28 | 0.20 | 0.10 | $E_1$ |
Hard routing fractions (count how many tokens each expert actually receives):
$f_1 = 5/8 = 0.625, \quad f_2 = 2/8 = 0.25, \quad f_3 = 1/8 = 0.125, \quad f_4 = 0/8 = 0.0$
$E_4$ is a dead expert: zero tokens routed to it.
Average soft probabilities (average each column):
$p_1 = 0.338, \quad p_2 = 0.331, \quad p_3 = 0.216, \quad p_4 = 0.115$
Now compute both losses:
KL divergence from uniform (using $p_i$ only):
$D_{KL}(p \| u) = 0.338 \ln(4 \cdot 0.338) + 0.331 \ln(4 \cdot 0.331) + 0.216 \ln(4 \cdot 0.216) + 0.115 \ln(4 \cdot 0.115)$
$= 0.338(0.302) + 0.331(0.281) + 0.216(-0.145) + 0.115(-0.777)$
$= 0.102 + 0.093 - 0.031 - 0.089 = \mathbf{0.075}$
This looks mild. The soft distribution $p$ is not that far from uniform: $[0.338, 0.331, 0.216, 0.115]$ vs. $[0.25, 0.25, 0.25, 0.25]$.
Load balancing loss (using $f_i \cdot p_i$):
$N \sum_i f_i \cdot p_i = 4 \times (0.625 \cdot 0.338 + 0.25 \cdot 0.331 + 0.125 \cdot 0.216 + 0.0 \cdot 0.115)$
$= 4 \times (0.211 + 0.083 + 0.027 + 0.0) = 4 \times 0.321 = \mathbf{1.284}$
Under perfect balance, $f_i = p_i = 0.25$, and the loss would be $4 \times 4 \times (0.25 \times 0.25) = 1.0$. So 1.284 vs. 1.0: a 28% penalty. This captures the severity of the imbalance much better.
The critical difference: the KL divergence sees only the soft probabilities, which look almost reasonable. It completely misses that $E_4$ receives zero tokens and $E_1$ receives 62.5% of them. The load balancing loss, by multiplying with the hard fractions $f_i$, directly penalizes the actual routing imbalance. The $0.625 \times 0.338$ term for $E_1$ dominates the loss, correctly flagging the expert that is hoarding tokens.
In short: KL divergence of the soft distribution is a valid regularizer, but it is blind to the hard routing decisions. The soft probabilities can look balanced while the actual routing (after the argmax) is severely skewed. The $f_i \cdot p_i$ formulation is a practical compromise: it uses the hard counts to identify which experts are actually overloaded, and the soft probabilities to provide the differentiable gradient signal needed to fix it.
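Both numbers can be reproduced directly from the table. The script below uses the unrounded column averages, so the results can differ from the hand calculation in the last digit:

```python
import numpy as np

# The 8x4 soft routing probabilities from the worked example above.
P = np.array([
    [0.50, 0.30, 0.15, 0.05],
    [0.45, 0.35, 0.10, 0.10],
    [0.40, 0.25, 0.20, 0.15],
    [0.10, 0.55, 0.20, 0.15],
    [0.15, 0.50, 0.25, 0.10],
    [0.20, 0.20, 0.45, 0.15],
    [0.48, 0.22, 0.18, 0.12],
    [0.42, 0.28, 0.20, 0.10],
])
T, N = P.shape
f = np.bincount(P.argmax(axis=1), minlength=N) / T   # hard fractions [0.625, 0.25, 0.125, 0]
p = P.mean(axis=0)                                   # soft column averages

kl = np.sum(p * np.log(N * p))   # KL divergence from uniform: blind to f
lb = N * np.dot(f, p)            # unscaled load balancing loss: sees f
print(round(kl, 3), round(lb, 3))  # → 0.074 1.283
```

The gap between the two values (small KL, large load-balancing penalty) is the whole argument in numeric form.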
The $\alpha$ Trade-off
The coefficient $\alpha$ controls how strongly the model is pushed toward balanced routing:
- Too small ($\alpha < 0.001$): the language modeling loss dominates, experts collapse.
- Too large ($\alpha > 0.1$): perfect balance, but the router cannot specialize experts on different token types. Model quality drops because the auxiliary loss overrides the main training signal.
- Sweet spot: typically $\alpha \in [0.001, 0.01]$. Switch Transformer uses $0.01$; ST-MoE uses $0.001$.
Capacity Factor and Token Dropping
Load balancing losses are soft. Capacity factor is a hard constraint: each expert has a buffer that can hold at most $C$ tokens per batch. The capacity factor $\text{CF}$ defines this buffer size:

$$C = \text{CF} \times \frac{T \times k}{N}$$

$T$ is the total tokens in the batch, $k$ is the number of experts per token, $N$ is the total number of experts. The term $T \times k / N$ is the expected number of tokens per expert under perfectly uniform routing. $\text{CF}$ is a multiplier (typically 1.0 to 1.5) that provides headroom.
When an expert's buffer is full, additional tokens routed to it are dropped: they skip the MoE layer entirely and pass through via the residual connection. This is a form of hard enforcement: no matter what the router wants, an expert cannot become a bottleneck.
Token dropping has a real cost: dropped tokens get no expert processing. At CF = 1.0, even mild imbalance causes drops. At CF = 1.5, you allocate 50% more buffer space than a balanced load needs, and the surplus sits as padding. GShard and Switch Transformer both use this mechanism, with CF typically 1.0-1.25.
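A minimal sketch of the overflow rule, assuming top-1 routing and first-come-first-served buffers (real systems batch this; the function name is an illustration):

```python
import numpy as np

def apply_capacity(assignments, N, k=1, cf=1.25):
    """Drop tokens that overflow each expert's buffer (sketch).

    assignments: (T,) expert index chosen for each token (top-1 only).
    Returns a boolean mask of tokens that were kept.
    """
    T = len(assignments)
    capacity = int(np.ceil(cf * T * k / N))   # C = CF * T * k / N, rounded up
    counts = np.zeros(N, dtype=int)
    kept = np.zeros(T, dtype=bool)
    for t, e in enumerate(assignments):       # tokens processed in batch order
        if counts[e] < capacity:
            counts[e] += 1
            kept[t] = True                    # fits in expert e's buffer
        # else: dropped -> the token passes through via the residual connection
    return kept

# 8 tokens, 4 experts, CF=1.25 -> capacity ceil(2.5) = 3 per expert;
# expert 0 is asked for 5 tokens, so the last 2 of them are dropped
kept = apply_capacity(np.array([0, 0, 0, 1, 1, 2, 0, 0]), N=4)
```

Note the order dependence: which tokens get dropped depends on their position in the batch, which is one reason token dropping is disliked.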
Router Z-Loss
ST-MoE (Zoph et al., 2022) identified a separate instability: router logits can grow unboundedly during training, causing numerical instability in the softmax and training crashes. Their fix is the router z-loss:

$$\mathcal{L}_z = \frac{1}{T} \sum_{t=1}^{T} \left( \log \sum_{i=1}^{N} e^{g_i^{(t)}} \right)^2$$
The log-sum-exp (LSE) is not an arbitrary choice here. It is the log of the softmax partition function $Z = \sum_i e^{g_i}$, and it has a specific meaning: it is a smooth, differentiable approximation to the maximum logit.
$$\max(g_1, \ldots, g_N) \;\leq\; \underbrace{\log \sum_{i=1}^{N} e^{g_i}}_{\text{LSE}} \;\leq\; \max(g_1, \ldots, g_N) + \log N$$

When one logit dominates (say $g_1 = 50$, all others near 0), $\text{LSE} \approx 50$. When all logits are equal at value $c$, $\text{LSE} = c + \log N$. So LSE tracks the scale of the largest logits. Penalizing $\text{LSE}^2$ directly penalizes the magnitude of the dominant logit, which is exactly what causes softmax saturation.
The connection to the softmax is direct. The softmax output is $\text{softmax}(g)_i = e^{g_i} / Z = e^{g_i - \text{LSE}}$. When $\text{LSE}$ is large because one logit is much bigger than the rest, the softmax output for that logit approaches 1.0, and the gradient $\partial \text{softmax}_i / \partial g_i = \text{softmax}_i(1 - \text{softmax}_i)$ approaches zero. The router becomes frozen: it cannot adjust its routing because the gradients vanish. The z-loss prevents this by keeping LSE (and therefore the logit scale) bounded.
ST-MoE uses a coefficient of $0.001$ for the z-loss, which is small enough to not interfere with the main loss but large enough to prevent the logit blow-up they observed. This is now standard practice: DeepSeek-V2/V3, Mixtral's follow-ups, and most recent MoE models include some form of router logit regularization.
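The z-loss is a one-liner given a numerically stable log-sum-exp. A NumPy sketch (the function name is an assumption):

```python
import numpy as np

def router_z_loss(logits, coef=1e-3):
    """ST-MoE router z-loss: coef * mean over tokens of LSE(logits)^2 (sketch).

    logits: (T, N) raw router logits.
    """
    m = logits.max(axis=1, keepdims=True)                    # subtract max for stability
    lse = m.squeeze(1) + np.log(np.exp(logits - m).sum(axis=1))
    return coef * np.mean(lse ** 2)

# one runaway logit makes LSE (and the penalty) blow up
big = np.array([[50.0, 0.0, 0.0, 0.0]])
small = np.array([[1.0, 0.0, 0.0, 0.0]])
print(router_z_loss(big) > router_z_loss(small))  # True
```

With all-zero logits, LSE is exactly $\log N$, matching the upper bound above.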
Shared Experts
DeepSeek-V2 introduced a straightforward architectural change: reserve $K_s$ experts as shared experts that process every token, regardless of routing. The remaining $N - K_s$ experts are routed as usual.
The shared experts capture common, token-independent patterns (basic syntax, frequent co-occurrences, positional patterns). This takes pressure off the routed experts, which can focus on specialized, content-dependent processing. The result: routed experts specialize more cleanly, and the model is less sensitive to routing imbalances because the shared experts provide a reliable baseline for every token.
DeepSeek-V2 uses $K_s = 2$ shared experts alongside 160 routed experts (of which 6 are selected per token). DeepSeek-V3 continues this design with 1 shared expert and 256 routed experts.
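The layer structure is easy to sketch: the shared experts run unconditionally and their outputs are added to the gated sum of the routed experts. This is an illustration of the idea only; DeepSeek's exact gating and normalization details differ.

```python
import numpy as np

def shared_plus_routed(h, shared, routed, W_r, k=2):
    """MoE layer with always-on shared experts plus top-k routed experts (sketch).

    h: (d,) hidden state; shared/routed: lists of callables (d,) -> (d,);
    W_r: (N_routed, d) router weights over the routed experts only.
    """
    out = sum(E(h) for E in shared)          # every token hits the shared experts
    g = W_r @ h                              # routing happens only among the rest
    topk = np.argsort(g)[-k:]
    z = np.exp(g[topk] - g[topk].max())
    G = z / z.sum()
    return out + sum(G[j] * routed[i](h) for j, i in enumerate(topk))

# toy usage: 1 shared identity expert, 4 routed scaling experts
rng = np.random.default_rng(0)
y = shared_plus_routed(np.ones(3), [lambda x: x],
                       [lambda x, s=s: s * x for s in (1.0, 2.0, 3.0, 4.0)],
                       rng.normal(size=(4, 3)))
```

Because the shared path bypasses the router entirely, even a badly routed token still gets the shared experts' processing.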
Expert Choice Routing
Standard top-$k$ routing asks: "which experts should this token go to?" Expert Choice routing (Zhou et al., 2022) flips the question: "which tokens should this expert process?"
Each expert independently selects its top-$k'$ tokens from the batch, where $k'$ is a fixed capacity:

$$\mathbf{s}_i = \mathbf{w}_i^\top \mathbf{H}$$

$\mathbf{s}_i \in \mathbb{R}^T$ is the affinity score between expert $i$ and every token, computed from expert $i$'s router weight vector $\mathbf{w}_i \in \mathbb{R}^d$. $\mathbf{H} \in \mathbb{R}^{d \times T}$ is the matrix of all token hidden states. Expert $i$ picks the $k'$ tokens with the highest affinity.
This guarantees perfect load balance by construction: every expert processes exactly $k'$ tokens. No auxiliary loss needed. No capacity overflow. No token dropping.
The trade-off: different tokens may be processed by different numbers of experts (or zero experts). A popular token might be selected by many experts; a boring token might be selected by none. Expert Choice models handle the zero-expert case through the residual connection.
Expert Choice routing has a significant practical limitation: it requires knowing all tokens in a batch before routing, making it incompatible with autoregressive generation where tokens arrive one at a time. This limits it to encoder models and training-time use.
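The selection step itself is just a per-expert top-$k'$ over the score matrix. A NumPy sketch (function name and shapes assumed as in the equation above):

```python
import numpy as np

def expert_choice_route(H, W_r, capacity):
    """Expert Choice: each expert picks its top-`capacity` tokens (sketch).

    H: (d, T) token hidden states; W_r: (N, d) router weights.
    Returns an (N, capacity) array of selected token indices.
    """
    S = W_r @ H                                   # (N, T) expert-token affinities
    # each expert independently takes its highest-affinity tokens
    return np.argsort(S, axis=1)[:, -capacity:]

# toy: 4 experts, 16 tokens, capacity 4 -> perfect balance by construction
rng = np.random.default_rng(1)
idx = expert_choice_route(rng.normal(size=(8, 16)), rng.normal(size=(4, 8)), capacity=4)
```

Every row has exactly `capacity` entries, so load balance is guaranteed; but nothing guarantees that every token index appears in some row, which is the zero-expert case handled by the residual connection.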
Fine-Grained Expert Segmentation
DeepSeek-MoE (and later DeepSeek-V2/V3) introduced fine-grained expert segmentation: instead of $N$ large experts, split each expert into $m$ smaller segments, giving $N \times m$ fine-grained experts total. Then increase the top-$k$ proportionally so the total activated parameters stay the same.
For example, instead of 16 experts with top-2 routing (activating 2 large FFNs), use 64 fine-grained experts with top-8 routing (activating 8 smaller FFNs with the same total size).
The total activated parameters per token is identical: $k \times d_{\text{ffn}} = (k \times m) \times (d_{\text{ffn}} / m)$.
Why does this help?
- More flexible combinations. With 16 experts and top-2, there are $\binom{16}{2} = 120$ possible expert combinations. With 64 experts and top-8, there are $\binom{64}{8} \approx 4.4$ billion. The model can compose much finer-grained specializations.
- Better knowledge isolation. Each fine-grained expert is smaller, so it naturally learns a narrower function. Knowledge is less entangled across capabilities.
- Easier load balancing. More experts means the law of large numbers helps: natural variance in routing is smaller relative to the number of experts. The load balancing problem becomes easier without needing stronger auxiliary losses.
DeepSeek-MoE uses $m = 4$ (splitting 16 base experts into 64 fine-grained experts with top-6 routing), and reports consistent quality improvements over the coarse-grained baseline at the same compute budget.
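The combination count from the first bullet is easy to verify:

```python
from math import comb

coarse = comb(16, 2)   # expert combinations with 16 experts, top-2
fine = comb(64, 8)     # with 64 fine-grained experts, top-8
print(coarse, fine)    # → 120 4426165368
```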
Auxiliary-Loss-Free Load Balancing
DeepSeek-V3 takes a radical approach: drop the auxiliary loss entirely and achieve load balance through a bias term in the router that is adjusted dynamically during training.
The idea: add a per-expert bias $b_i$ to the router logits before the top-$k$ selection (but not in the gating weight calculation):

$$\mathcal{T} = \operatorname{top-}k\!\left(g_1 + b_1, \ldots, g_N + b_N\right)$$

while the gating weights $G_i$ are still computed from the unbiased logits $g_i$.
The bias $b_i$ is not learned by gradient descent. Instead, it is updated by a simple rule: if expert $i$ is overloaded in the current batch, decrease $b_i$ by a small constant $\gamma$; if underloaded, increase $b_i$ by $\gamma$.
$\gamma$ is a very small step, on the order of $0.001$. This is a form of integral control: the bias accumulates small corrections over many batches, gently steering the router toward balance without ever introducing a loss term that fights the language modeling objective.
DeepSeek-V3 reports that this approach achieves better model quality than auxiliary-loss-based methods while maintaining comparable load balance. The key advantage: there is zero tension between the language modeling loss and the balancing mechanism, because they operate in completely separate channels (gradients vs. heuristic bias updates).
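The update rule is a few lines. This sketch shows the sign-based step on the per-expert load fractions; the function name and the exact overload criterion are assumptions, not DeepSeek-V3's published code.

```python
import numpy as np

def update_bias(b, f, gamma=1e-3):
    """Aux-loss-free balancing: nudge each expert's routing bias down if it is
    overloaded and up if it is underloaded, by a fixed step gamma (sketch).

    b: (N,) current biases; f: (N,) fraction of tokens each expert received
    in the current batch.
    """
    target = 1.0 / len(f)                    # perfectly balanced load
    return b - gamma * np.sign(f - target)   # integral-style correction, no gradients

b = np.zeros(4)
f = np.array([0.625, 0.25, 0.125, 0.0])      # loads from the worked example
b = update_bias(b, f)
# the hoarding expert's bias drops; the starved experts' biases rise
```

Because `b` only shifts the top-$k$ selection and never enters the loss, there is no gradient pathway for it to fight the language modeling objective.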
Soft MoE: No Hard Routing at All
Soft MoE (Puigcerver et al., 2023) eliminates discrete routing entirely. Instead of sending individual tokens to experts, it creates soft, weighted combinations of all tokens and sends those to experts.
The mechanism works in two steps:

1. Dispatch: each of $S$ expert "slots" receives a weighted average of all input tokens, with weights given by a softmax over tokens. Each expert then processes its share of the slots.
2. Combine: each token's output is a weighted average of all slot outputs, with weights given by a softmax over slots.

There is no top-$k$ selection, no discrete routing decision, and therefore no load balancing problem at all. Every expert always processes the same number of slots, and every token contributes to (and receives from) every expert through soft weights.
The trade-off is computational: Soft MoE requires attention-like operations (the dispatch and combine matrices involve all-to-all token interactions), which partially negates the efficiency advantage of sparse MoE. It also cannot easily scale to very large expert counts because the dispatch weights become a bottleneck.
Soft MoE has been primarily demonstrated in vision models (ViT-based). Its application to autoregressive language models is limited because the dispatch step requires access to all tokens simultaneously, which conflicts with causal generation.
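The dispatch/combine steps can be sketched in NumPy. This is a simplified single-head version of Puigcerver et al.'s mechanism; the slot-parameter matrix `Phi` and the even division of slots among experts are assumptions of the sketch.

```python
import numpy as np

def softmax(x, axis):
    z = np.exp(x - x.max(axis=axis, keepdims=True))
    return z / z.sum(axis=axis, keepdims=True)

def soft_moe(X, Phi, experts):
    """Soft MoE layer: soft dispatch -> expert FFNs -> soft combine (sketch).

    X: (T, d) tokens; Phi: (d, S) learned per-slot parameters;
    experts: list of callables (d,) -> (d,), slots split evenly among them.
    """
    logits = X @ Phi                     # (T, S) token-slot affinities
    D = softmax(logits, axis=0)          # dispatch weights: softmax over tokens
    slots = D.T @ X                      # (S, d) soft mixes of all tokens
    per = len(slots) // len(experts)     # slots assigned to each expert
    Y = np.concatenate([np.stack([E(s) for s in slots[i*per:(i+1)*per]])
                        for i, E in enumerate(experts)])   # (S, d) slot outputs
    C = softmax(logits, axis=1)          # combine weights: softmax over slots
    return C @ Y                         # (T, d) per-token outputs

# toy usage: 6 tokens, d=4, 4 slots shared by 2 experts
rng = np.random.default_rng(2)
Y = soft_moe(rng.normal(size=(6, 4)), rng.normal(size=(4, 4)),
             [lambda x: x, lambda x: 2 * x])
```

The dispatch softmax over the token axis is exactly the step that requires seeing the whole sequence at once, which is what blocks causal generation.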
Comparison Table
| Technique | How it works | Removes aux loss? | Autoregressive? | Used in |
|---|---|---|---|---|
| Load balancing loss | Auxiliary loss penalizing $f_i \cdot p_i$ | No (is the aux loss) | Yes | Switch, GShard, Mixtral |
| Capacity factor | Hard buffer limit per expert | No (used alongside) | Yes | Switch, GShard |
| Router z-loss | Penalizes large router logits | No (adds another aux) | Yes | ST-MoE, DeepSeek-V2/V3 |
| Shared experts | Always-on experts + routed experts | No (orthogonal) | Yes | DeepSeek-V2/V3 |
| Expert Choice | Experts select tokens (not vice versa) | Yes | No | EC (Zhou et al., 2022) |
| Fine-grained experts | Split experts into smaller segments | No (orthogonal) | Yes | DeepSeek-MoE/V2/V3 |
| Aux-loss-free bias | Dynamic bias on router logits | Yes | Yes | DeepSeek-V3 |
| Soft MoE | Soft dispatch, no discrete routing | Yes (no routing) | No | Soft MoE (Puigcerver, 2023) |
Practical Notes
What most production systems actually use
Modern MoE LLMs (Mixtral, DeepSeek-V3, Grok, DBRX) typically combine several of these techniques rather than relying on any single one. A common recipe:
- Top-$k$ token-choice routing ($k$ = 2 for Mixtral, $k$ = 8 for DeepSeek-V3 with fine-grained experts)
- Some form of load balancing (auxiliary loss or DeepSeek's bias method)
- Router z-loss for training stability
- Shared experts (DeepSeek family)
How many experts?
There is no consensus. Mixtral uses 8 experts with top-2 (modest sparsity). DeepSeek-V3 uses 257 (1 shared + 256 routed) with top-8 out of 256 routed (high sparsity). More experts generally means more total parameters at the same compute cost, but the routing problem gets harder and communication overhead grows in distributed training.
Expert parallelism is the real bottleneck
The dominant practical concern is not routing quality but communication cost. In a distributed setup, each expert lives on a different device. Routing tokens to their assigned experts requires all-to-all communication across devices, which is expensive and latency-sensitive. This is why capacity factors and token dropping exist: they bound the worst-case communication load. DeepSeek-V3 addresses this with device-constrained routing, where a token can only be routed to experts on a limited set of devices, reducing the all-to-all communication to a smaller group.
MoE and fine-tuning
Fine-tuning MoE models requires care. If the fine-tuning data is small or narrowly distributed, the router may shift to routing all tokens to a few experts, and the unused experts degrade. Common mitigations: freeze the router during fine-tuning (only train expert FFNs), use a higher $\alpha$ for load balancing during fine-tuning, or use LoRA on expert weights to limit the parameter changes.
Evaluating expert utilization
Track these metrics during training to catch problems early:
- Expert load standard deviation. The standard deviation of $f_i$ across experts. Should be small relative to $1/N$.
- Dead expert rate. The fraction of experts with $f_i < \epsilon$ for some threshold $\epsilon$ (say, $0.001$). Should be zero.
- Router entropy. The entropy of the router's softmax distribution, averaged over tokens. Low entropy means the router is very confident (possibly collapsed); high entropy means it is hedging (possibly not specializing).
- Expert output similarity. Cosine similarity between the outputs of different experts on the same input. High similarity means representation collapse: experts are computing the same thing.
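The first three metrics fall out of the router probabilities directly. A monitoring sketch (the function name is an assumption; expert output similarity needs the expert outputs themselves, so it is omitted here):

```python
import numpy as np

def routing_metrics(router_probs, eps=1e-3):
    """Expert-utilization metrics from (T, N) router softmax probabilities (sketch)."""
    T, N = router_probs.shape
    f = np.bincount(router_probs.argmax(axis=1), minlength=N) / T
    ent = -np.sum(router_probs * np.log(router_probs + 1e-12), axis=1)
    return {
        "load_std": f.std(),            # compare against the 1/N scale
        "dead_rate": np.mean(f < eps),  # fraction of starved experts; want 0
        "router_entropy": ent.mean(),   # low = confident/collapsed, high = hedging
    }

# toy: each of 4 experts gets exactly 2 of 8 tokens with full confidence
probs = np.tile(np.eye(4), (2, 1))
m = routing_metrics(probs)
```

Logging these per layer per few hundred steps is cheap and catches collapse long before the loss curve shows it.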