Training MoE Right: Making Every Expert Count

Mixture-of-Experts models activate only a fraction of their parameters per token. The hard part is making sure each expert actually learns something useful.

The Expert Collapse Problem

The core idea of Mixture-of-Experts (MoE) is simple: instead of one giant FFN layer, use $N$ smaller FFN "experts" and a learned router that sends each token to the best $k$ of them. This gives you a model with $N \times$ the parameters but only $k \times$ the per-token compute.

The problem is that routing is learned jointly with the experts, and the learning dynamics conspire against you. A few failure modes dominate:

- Rich-get-richer collapse: an expert that starts out slightly better attracts more tokens, receives more gradient, and gets better still, until the router sends nearly everything to it.
- Dead experts: an expert that receives no tokens gets no gradient and can never recover, leaving its parameters wasted.
- Load imbalance: even short of full collapse, uneven routing overloads some experts and underuses others, wasting capacity and creating bottlenecks in distributed training.

Every technique in this post addresses one or more of these failure modes. The goal is always the same: encourage the router to spread tokens across experts so that each expert specializes on a meaningful subset of the input distribution.

[Figure: a router (softmax) dispatches input tokens to the top-k = 2 of N experts; the selected experts' outputs are combined as a weighted sum.]

How Routing Works

The router is typically a single linear layer that takes a token's hidden state $\mathbf{h} \in \mathbb{R}^d$ and produces logits over $N$ experts:

Router logits
$$\mathbf{g} = \mathbf{W}_r \mathbf{h}, \quad \mathbf{W}_r \in \mathbb{R}^{N \times d}$$

where $\mathbf{g} \in \mathbb{R}^N$ is a score for each expert. The router then selects the top-$k$ experts and computes gating weights using a softmax over the selected logits:

Gating weights (top-k)
$$G_i = \begin{cases} \frac{e^{g_i}}{\sum_{j \in \text{TopK}} e^{g_j}} & \text{if } i \in \text{TopK}(\mathbf{g}, k) \\ 0 & \text{otherwise} \end{cases}$$

$G_i$ is the gating weight for expert $i$. The MoE layer output is the weighted sum of the selected experts' outputs:

MoE output
$$\mathbf{y} = \sum_{i=1}^{N} G_i \cdot E_i(\mathbf{h})$$

where $E_i(\mathbf{h})$ is expert $i$'s FFN output. Since most $G_i$ are zero, only $k$ FFN forward passes actually run.

Intuition: The router is a learned dispatcher. It looks at each token and decides which $k$ specialists should handle it. The gating weight $G_i$ controls how much each specialist's answer contributes to the final output. The catch: this dispatcher is trained from scratch alongside the specialists, and nothing in the base loss function prevents it from always calling the same specialist.
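To make the mechanics concrete, here is a minimal NumPy sketch of the routing equations above. The "experts" are stand-in random linear maps, not real FFNs, and all dimensions are toy values:

```python
import numpy as np

def moe_forward(h, W_r, experts, k=2):
    """Route one token's hidden state h through the top-k of N experts.

    h:       (d,) token hidden state
    W_r:     (N, d) router weight matrix
    experts: list of N callables, each mapping (d,) -> (d,)
    """
    g = W_r @ h                           # router logits g, shape (N,)
    topk = np.argsort(g)[-k:]             # indices of the k largest logits
    z = np.exp(g[topk] - g[topk].max())   # softmax over selected logits only
    G = z / z.sum()                       # gating weights, sum to 1
    # weighted sum of the k selected experts' outputs
    return sum(G[j] * experts[i](h) for j, i in enumerate(topk))

rng = np.random.default_rng(0)
d, N = 16, 8
W_r = rng.normal(size=(N, d))
# toy "experts": random linear maps standing in for FFN blocks
mats = [rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(N)]
experts = [lambda x, M=M: M @ x for M in mats]

y = moe_forward(rng.normal(size=d), W_r, experts, k=2)
print(y.shape)  # (16,)
```

Only the two selected expert callables ever run, which is exactly why most $G_i$ being zero translates into compute savings.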

Load Balancing Losses

The oldest and most common fix: add an auxiliary loss that penalizes uneven expert utilization. The idea appeared in GShard and was refined by the Switch Transformer.

Switch Transformer's Load Balancing Loss

For a batch of $T$ tokens and $N$ experts, define two quantities:

Fraction of tokens routed to expert $i$
$$f_i = \frac{1}{T} \sum_{t=1}^{T} \mathbf{1}[\text{argmax}(\mathbf{g}^{(t)}) = i]$$

$f_i$ is the fraction of tokens for which expert $i$ was the top choice. This is a hard count.

Average router probability for expert $i$
$$p_i = \frac{1}{T} \sum_{t=1}^{T} \text{softmax}(\mathbf{g}^{(t)})_i$$

$p_i$ is the average softmax probability assigned to expert $i$ across all tokens. This is a soft, differentiable quantity.

The auxiliary loss is their dot product, summed over experts:

Load balancing loss
$$\mathcal{L}_{\text{bal}} = \alpha \cdot N \cdot \sum_{i=1}^{N} f_i \cdot p_i$$

$\alpha$ is a small coefficient, typically 0.01. $N$ is a scaling factor that makes the loss value independent of the number of experts.

Intuition: If expert $i$ gets a disproportionate fraction of tokens ($f_i$ is high), then the loss penalizes the router for also giving it high probability ($p_i$ is high). The gradient pushes the router to lower $p_i$ for overloaded experts and raise it for underloaded ones. The product $f_i \cdot p_i$ is minimized when both distributions are uniform, i.e., $f_i = p_i = 1/N$ for all $i$.

Why use both $f_i$ and $p_i$ instead of just one? $f_i$ involves an argmax, so it is not differentiable. $p_i$ is differentiable but does not capture the actual routing decision (a token with $p_i = 0.49$ for expert A and $0.51$ for expert B still routes to B). Their product combines the hard routing reality ($f_i$) with a differentiable signal ($p_i$) that gradients can flow through.
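Putting $f_i$ and $p_i$ together, here is a minimal NumPy sketch of the Switch-style loss. The router logits are synthetic, purely to contrast balanced and skewed routing:

```python
import numpy as np

def load_balancing_loss(logits, alpha=0.01):
    """Switch Transformer auxiliary loss: alpha * N * sum_i f_i * p_i.

    logits: (T, N) router logits for T tokens and N experts.
    """
    T, N = logits.shape
    probs = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs /= probs.sum(axis=1, keepdims=True)                 # per-token softmax
    f = np.bincount(logits.argmax(axis=1), minlength=N) / T   # hard fractions
    p = probs.mean(axis=0)                                    # soft averages
    return alpha * N * np.sum(f * p)

rng = np.random.default_rng(0)
balanced = rng.normal(scale=0.1, size=(1024, 8))  # near-uniform routing
skewed = balanced.copy()
skewed[:, 0] += 3.0                               # one expert dominates
print(load_balancing_loss(balanced), load_balancing_loss(skewed))
```

Under near-uniform routing the loss sits close to $\alpha = 0.01$ (its minimum), while the skewed batch is several times larger, which is the gradient signal that pushes the router back toward balance.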

Why Not KL Divergence from Uniform?

A natural alternative: measure how far the routing distribution is from uniform using KL divergence. If the target is a uniform distribution $u_i = 1/N$, the KL divergence from uniform is:

KL divergence from uniform
$$D_{KL}(p \,\|\, u) = \sum_{i=1}^{N} p_i \log \frac{p_i}{1/N} = \log N + \sum_{i=1}^{N} p_i \log p_i = \log N - H(p)$$

where $H(p) = -\sum_i p_i \log p_i$ is the entropy of $p$. Minimizing $D_{KL}(p \| u)$ is equivalent to maximizing the entropy of the router's soft distribution. This is well-defined and differentiable.

Worked example: KL divergence vs. load balancing loss

Consider $N = 4$ experts and $T = 8$ tokens. Suppose the router produces these soft probabilities and hard routing decisions:

| Token | $p_{E_1}$ | $p_{E_2}$ | $p_{E_3}$ | $p_{E_4}$ | Routed to |
|---|---|---|---|---|---|
| $t_1$ | 0.50 | 0.30 | 0.15 | 0.05 | $E_1$ |
| $t_2$ | 0.45 | 0.35 | 0.10 | 0.10 | $E_1$ |
| $t_3$ | 0.40 | 0.25 | 0.20 | 0.15 | $E_1$ |
| $t_4$ | 0.10 | 0.55 | 0.20 | 0.15 | $E_2$ |
| $t_5$ | 0.15 | 0.50 | 0.25 | 0.10 | $E_2$ |
| $t_6$ | 0.20 | 0.20 | 0.45 | 0.15 | $E_3$ |
| $t_7$ | 0.48 | 0.22 | 0.18 | 0.12 | $E_1$ |
| $t_8$ | 0.42 | 0.28 | 0.20 | 0.10 | $E_1$ |

Hard routing fractions (count how many tokens each expert actually receives):

$f_1 = 5/8 = 0.625, \quad f_2 = 2/8 = 0.25, \quad f_3 = 1/8 = 0.125, \quad f_4 = 0/8 = 0.0$

$E_4$ is a dead expert: zero tokens routed to it.

Average soft probabilities (average each column):

$p_1 = 0.338, \quad p_2 = 0.331, \quad p_3 = 0.216, \quad p_4 = 0.115$

Now compute both losses:

KL divergence from uniform (using $p_i$ only):

$D_{KL}(p \| u) = 0.338 \ln(4 \cdot 0.338) + 0.331 \ln(4 \cdot 0.331) + 0.216 \ln(4 \cdot 0.216) + 0.115 \ln(4 \cdot 0.115)$

$= 0.338(0.302) + 0.331(0.281) + 0.216(-0.145) + 0.115(-0.768)$

$= 0.102 + 0.093 - 0.031 - 0.088 = \mathbf{0.076}$

This looks mild. The soft distribution $p$ is not that far from uniform: $[0.338, 0.331, 0.216, 0.115]$ vs. $[0.25, 0.25, 0.25, 0.25]$.

Load balancing loss (using $f_i \cdot p_i$):

$N \sum_i f_i \cdot p_i = 4 \times (0.625 \cdot 0.338 + 0.25 \cdot 0.331 + 0.125 \cdot 0.216 + 0.0 \cdot 0.115)$

$= 4 \times (0.211 + 0.083 + 0.027 + 0.0) = 4 \times 0.321 = \mathbf{1.284}$

Under perfect balance, $f_i = p_i = 0.25$, and the loss would be $4 \times 4 \times (0.25 \times 0.25) = 1.0$. So 1.284 vs. 1.0: a 28% penalty. This captures the severity of the imbalance much better.

The critical difference: the KL divergence sees only the soft probabilities, which look almost reasonable. It completely misses that $E_4$ receives zero tokens and $E_1$ receives 62.5% of them. The load balancing loss, by multiplying with the hard fractions $f_i$, directly penalizes the actual routing imbalance. The $0.625 \times 0.338$ term for $E_1$ dominates the loss, correctly flagging the expert that is hoarding tokens.

In short: KL divergence of the soft distribution is a valid regularizer, but it is blind to the hard routing decisions. The soft probabilities can look balanced while the actual routing (after the argmax) is severely skewed. The $f_i \cdot p_i$ formulation is a practical compromise: it uses the hard counts to identify which experts are actually overloaded, and the soft probabilities to provide the differentiable gradient signal needed to fix it.
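The worked example can be checked directly in a few lines of Python. Exact arithmetic gives 0.074 and 1.283; the small differences from the hand computation come from rounding the logarithms above:

```python
import math

P = [  # soft router probabilities per token (E1..E4), from the table above
    [0.50, 0.30, 0.15, 0.05], [0.45, 0.35, 0.10, 0.10],
    [0.40, 0.25, 0.20, 0.15], [0.10, 0.55, 0.20, 0.15],
    [0.15, 0.50, 0.25, 0.10], [0.20, 0.20, 0.45, 0.15],
    [0.48, 0.22, 0.18, 0.12], [0.42, 0.28, 0.20, 0.10],
]
T, N = len(P), 4

# hard routing fractions: each token goes to its argmax expert
counts = [0] * N
for row in P:
    counts[row.index(max(row))] += 1
f = [c / T for c in counts]

# average soft probability per expert (column means)
p = [sum(row[i] for row in P) / T for i in range(N)]

kl = sum(pi * math.log(N * pi) for pi in p)    # D_KL(p || uniform)
lb = N * sum(fi * pi for fi, pi in zip(f, p))  # N * sum_i f_i * p_i

print(f)                            # [0.625, 0.25, 0.125, 0.0]
print(round(kl, 3), round(lb, 3))   # 0.074 1.283
```

Note how the dead expert shows up only through $f_4 = 0$: the KL term never sees it, while the load balancing loss zeroes out its contribution and lets the overloaded $E_1$ term dominate.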

The $\alpha$ Trade-off

The coefficient $\alpha$ controls how strongly the model is pushed toward balanced routing:

- Too small: the auxiliary loss is ignored and imbalance or collapse persists.
- Too large: the balancing term overwhelms the language modeling loss, forcing near-uniform routing and preventing experts from specializing.
- In practice, values around 0.01 are the common compromise.

Key limitation: The load balancing loss fights against specialization. It pushes toward uniform routing, but the whole point of MoE is to learn non-uniform, specialized routing. Every MoE system is navigating this tension. The rest of this post covers techniques that achieve balance with less damage to expert quality.

Capacity Factor and Token Dropping

Load balancing losses are soft. Capacity factor is a hard constraint: each expert has a buffer that can hold at most $C$ tokens per batch. The capacity factor $\text{CF}$ defines this buffer size:

Expert buffer size
$$C = \text{CF} \times \frac{T \times k}{N}$$

$T$ is the total tokens in the batch, $k$ is the number of experts per token, $N$ is the total number of experts. The term $T \times k / N$ is the expected number of tokens per expert under perfectly uniform routing. $\text{CF}$ is a multiplier (typically 1.0 to 1.5) that provides headroom.
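As a quick sanity check of the formula (the batch size, $k$, and $N$ here are hypothetical values, not from any particular system):

```python
def expert_capacity(T, k, N, cf):
    """Buffer size C = CF * (T * k / N) tokens per expert."""
    return int(cf * T * k / N)

# e.g. a batch of 4096 tokens, top-2 routing, 8 experts
print(expert_capacity(4096, 2, 8, 1.0))   # 1024
print(expert_capacity(4096, 2, 8, 1.25))  # 1280
```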

When an expert's buffer is full, additional tokens routed to it are dropped: they skip the MoE layer entirely and pass through via the residual connection. This is a form of hard enforcement: no matter what the router wants, an expert cannot become a bottleneck.

[Figure: three experts, each with capacity 8. Expert 1 holds 6 tokens, Expert 3 holds 2, but Expert 2's buffer is full at 8 tokens, so its overflow tokens are dropped and skip the layer via the residual connection.]

Token dropping has a real cost: dropped tokens get no expert processing. At CF = 1.0, even mild imbalance causes drops. At CF = 1.5, you allocate 50% extra buffer space, most of which sits unused as padding whenever routing is roughly balanced. GShard and Switch Transformer both use this mechanism, with CF typically 1.0-1.25.

Warning: Token dropping means the model's output is non-deterministic with respect to the other tokens in the batch. The same token can get different results depending on what other tokens are in the batch and whether they fill up the buffer first. This makes debugging and reproducibility harder.

Router Z-Loss

ST-MoE (Zoph et al., 2022) identified a separate instability: router logits can grow unboundedly during training, causing numerical instability in the softmax and training crashes. Their fix is the router z-loss:

Router z-loss
$$\mathcal{L}_z = \frac{1}{T} \sum_{t=1}^{T} \left( \log \sum_{i=1}^{N} e^{g_i^{(t)}} \right)^2$$

The log-sum-exp (LSE) is not an arbitrary choice here. It is the log of the softmax partition function $Z = \sum_i e^{g_i}$, and it has a specific meaning: it is a smooth, differentiable approximation to the maximum logit.

$$\max(g_1, \ldots, g_N) \;\leq\; \underbrace{\log \sum_{i=1}^{N} e^{g_i}}_{\text{LSE}} \;\leq\; \max(g_1, \ldots, g_N) + \log N$$

When one logit dominates (say $g_1 = 50$, all others near 0), $\text{LSE} \approx 50$. When all logits are equal at value $c$, $\text{LSE} = c + \log N$. So LSE tracks the scale of the largest logits. Penalizing $\text{LSE}^2$ directly penalizes the magnitude of the dominant logit, which is exactly what causes softmax saturation.

The connection to the softmax is direct. The softmax output is $\text{softmax}(g)_i = e^{g_i} / Z = e^{g_i - \text{LSE}}$. When $\text{LSE}$ is large because one logit is much bigger than the rest, the softmax output for that logit approaches 1.0, and the gradient $\partial \text{softmax}_i / \partial g_i = \text{softmax}_i(1 - \text{softmax}_i)$ approaches zero. The router becomes frozen: it cannot adjust its routing because the gradients vanish. The z-loss prevents this by keeping LSE (and therefore the logit scale) bounded.
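A small pure-Python sketch makes the LSE-tracks-the-max behavior and the z-loss computation concrete:

```python
import math

def lse(logits):
    """Numerically stable log-sum-exp: log(sum_i exp(g_i))."""
    m = max(logits)
    return m + math.log(sum(math.exp(g - m) for g in logits))

def z_loss(batch_logits):
    """Router z-loss: mean of squared log-partition over tokens."""
    return sum(lse(g) ** 2 for g in batch_logits) / len(batch_logits)

# LSE tracks the largest logit...
print(lse([50.0, 0.0, 0.0, 0.0]))  # ~50.0 (one dominant logit)
print(lse([2.0, 2.0, 2.0, 2.0]))   # 2 + ln(4) ~ 3.386 (all-equal upper bound)

# ...so the z-loss explodes exactly when the router's logits blow up
print(z_loss([[50.0, 0.0, 0.0, 0.0]]) > z_loss([[0.5, 0.2, -0.1, 0.0]]))
```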

Intuition: The "z" in z-loss stands for the partition function $Z$. The loss penalizes $(\log Z)^2$, which is the squared log-normalizer of the softmax. Without it, the router can learn to produce very large logits (say, +100 for one expert and -100 for all others), making the softmax numerically saturated and the routing effectively a hard, non-differentiable decision. The z-loss keeps logits in a range where the softmax still has meaningful gradients, allowing the router to continue learning and adjusting throughout training. Think of it as preventing the router from becoming "too sure of itself" too early.

ST-MoE uses a coefficient of $0.001$ for the z-loss, which is small enough to not interfere with the main loss but large enough to prevent the logit blow-up they observed. This is now standard practice: DeepSeek-V2/V3, Mixtral's follow-ups, and most recent MoE models include some form of router logit regularization.

Shared Experts

DeepSeek-V2 introduced a straightforward architectural change: reserve $K_s$ experts as shared experts that process every token, regardless of routing. The remaining $N - K_s$ experts are routed as usual.

MoE output with shared experts
$$\mathbf{y} = \underbrace{\sum_{i=1}^{K_s} E_i^{\text{shared}}(\mathbf{h})}_{\text{always active}} + \underbrace{\sum_{j \in \text{TopK}} G_j \cdot E_j^{\text{routed}}(\mathbf{h})}_{\text{conditionally active}}$$
[Figure: a token $\mathbf{h}$ always flows through the shared experts $E_1^s, E_2^s$ and, via the router, through the top-k selected routed experts; the two paths are summed into $\mathbf{y}$.]

The shared experts capture common, token-independent patterns (basic syntax, frequent co-occurrences, positional patterns). This takes pressure off the routed experts, which can focus on specialized, content-dependent processing. The result: routed experts specialize more cleanly, and the model is less sensitive to routing imbalances because the shared experts provide a reliable baseline for every token.

DeepSeek-V2 uses $K_s = 2$ shared experts alongside 160 routed experts (of which 6 are selected per token). DeepSeek-V3 continues this design with 1 shared expert and 256 routed experts.
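A minimal sketch of the combined forward pass, in the same spirit as the routing equations earlier. Dimensions are toy values and random matrices stand in for the expert FFNs:

```python
import numpy as np

def moe_with_shared(h, shared, routed, W_r, k=2):
    """y = sum of shared experts (always on) + gated top-k routed experts."""
    y = sum(E(h) for E in shared)          # always active, no routing
    g = W_r @ h                            # router scores routed experts only
    topk = np.argsort(g)[-k:]
    z = np.exp(g[topk] - g[topk].max())
    G = z / z.sum()                        # gating weights over selected experts
    return y + sum(G[j] * routed[i](h) for j, i in enumerate(topk))

rng = np.random.default_rng(1)
d = 8
make = lambda: (lambda x, M=rng.normal(size=(d, d)) / np.sqrt(d): M @ x)
shared = [make() for _ in range(2)]   # K_s = 2 shared experts, every token
routed = [make() for _ in range(6)]   # routed experts, top-2 per token
W_r = rng.normal(size=(6, d))

y = moe_with_shared(rng.normal(size=d), shared, routed, W_r)
print(y.shape)  # (8,)
```

Note that the router never scores the shared experts at all: they are outside the balancing problem entirely, which is precisely why they stabilize it.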

Expert Choice Routing

Standard top-$k$ routing asks: "which experts should this token go to?" Expert Choice routing (Zhou et al., 2022) flips the question: "which tokens should this expert process?"

Each expert independently selects its top-$k'$ tokens from the batch, where $k'$ is a fixed capacity:

Expert Choice selection
$$S_i = \text{TopK}(\mathbf{s}_i, k'), \quad \mathbf{s}_i = \mathbf{W}_r[i,:] \, \mathbf{H}$$

$\mathbf{s}_i \in \mathbb{R}^T$ is the affinity score between expert $i$ and every token. $\mathbf{H} \in \mathbb{R}^{d \times T}$ is the matrix of all token hidden states. Expert $i$ picks the $k'$ tokens with the highest affinity.

This guarantees perfect load balance by construction: every expert processes exactly $k'$ tokens. No auxiliary loss needed. No capacity overflow. No token dropping.

[Figure: token choice vs. expert choice with 4 tokens. Under token choice, routing by token preference is imbalanced: $E_1$ gets 3 tokens, $E_2$ gets 1, $E_3$ gets 0. Under expert choice, each expert picks exactly 2 tokens, so load is perfectly balanced, but tokens receive a variable number of experts.]

The trade-off: different tokens may be processed by different numbers of experts (or zero experts). A popular token might be selected by many experts; a boring token might be selected by none. Expert Choice models handle the zero-expert case through the residual connection.

Intuition: Think of it like a draft in sports. In token choice, each player (token) picks their team (expert). Popular teams get too many players. In expert choice, each team picks their players. Every team gets exactly the right number, but some players might not get picked at all.

Expert Choice routing has a significant practical limitation: it requires knowing all tokens in a batch before routing, making it incompatible with autoregressive generation where tokens arrive one at a time. This limits it to encoder models and training-time use.
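A sketch of the selection step, following the affinity definition above (toy dimensions; the batch-level access to all tokens is exactly the property that breaks autoregressive decoding):

```python
import numpy as np

def expert_choice(H, W_r, k_prime):
    """Each expert independently picks its top-k' tokens from the batch.

    H:   (d, T) token hidden states (columns are tokens)
    W_r: (N, d) router weights; s_i = W_r[i,:] @ H is expert i's affinities
    Returns a list of N token-index lists, one per expert.
    """
    S = W_r @ H                                    # (N, T) affinity scores
    return [np.argsort(S[i])[-k_prime:].tolist() for i in range(S.shape[0])]

rng = np.random.default_rng(0)
d, T, N, k_prime = 8, 16, 4, 4
assign = expert_choice(rng.normal(size=(d, T)), rng.normal(size=(N, d)), k_prime)

# perfect load balance by construction: every expert gets exactly k' tokens
print([len(a) for a in assign])  # [4, 4, 4, 4]
# ...but some tokens are chosen by several experts and others by none
print(sorted(set(t for a in assign for t in a)))
```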

Fine-Grained Expert Segmentation

DeepSeek-MoE (and later DeepSeek-V2/V3) introduced fine-grained expert segmentation: instead of $N$ large experts, split each expert into $m$ smaller segments, giving $N \times m$ fine-grained experts total. Then increase the top-$k$ proportionally so the total activated parameters stay the same.

For example, instead of 16 experts with top-2 routing (activating 2 large FFNs), use 64 fine-grained experts with top-8 routing (activating 8 smaller FFNs with the same total size).

Fine-grained parameters
$$\text{Original: } N \text{ experts, top-}k, \text{ each of size } d_{\text{ffn}}$$

$$\text{Fine-grained: } N \times m \text{ experts, top-}(k \times m), \text{ each of size } d_{\text{ffn}} / m$$

The total activated parameters per token is identical: $k \times d_{\text{ffn}} = (k \times m) \times (d_{\text{ffn}} / m)$.
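The equality is easy to verify with the numbers from the 16-expert example above:

```python
# Activated FFN width per token is identical before and after segmentation.
N, k, d_ffn, m = 16, 2, 4096, 4
coarse = k * d_ffn             # top-2 of 16 experts, each of width d_ffn
fine = (k * m) * (d_ffn // m)  # top-8 of 64 experts, each of width d_ffn/4
print(coarse, fine)  # 8192 8192
```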

Why does this help?

DeepSeek-MoE uses $m = 4$ (the equivalent of 16 standard-sized experts split into 64 fine-grained experts, with top-6 routing), and reports consistent quality improvements over the coarse-grained baseline at the same compute budget.

Auxiliary-Loss-Free Load Balancing

DeepSeek-V3 takes a radical approach: drop the auxiliary loss entirely and achieve load balance through a bias term in the router that is adjusted dynamically during training.

The idea: add a per-expert bias $b_i$ to the router logits before the top-$k$ selection (but not in the gating weight calculation):

Biased routing
$$g_i' = g_i + b_i \quad \text{(used for top-}k \text{ selection only)}$$

$$G_i = \frac{e^{g_i}}{\sum_{j \in \text{TopK}(\mathbf{g}', k)} e^{g_j}} \quad \text{(gating uses original logits)}$$

The bias $b_i$ is not learned by gradient descent. Instead, it is updated by a simple rule: if expert $i$ is overloaded in the current batch, decrease $b_i$ by a small constant $\gamma$; if underloaded, increase $b_i$ by $\gamma$.

Bias update rule
$$b_i \leftarrow \begin{cases} b_i - \gamma & \text{if expert } i \text{ is in the top-}N_{\text{over}} \text{ most loaded} \\ b_i + \gamma & \text{if expert } i \text{ is in the top-}N_{\text{under}} \text{ least loaded} \end{cases}$$

$\gamma$ is a very small step, on the order of $0.001$. This is a form of integral control: the bias accumulates small corrections over many batches, gently steering the router toward balance without ever introducing a loss term that fights the language modeling objective.

Intuition: Think of the bias as a price mechanism. If an expert is in high demand, its "price" ($b_i$) goes down, making the router slightly less inclined to route tokens there. If an expert is underused, its price goes up, attracting more tokens. Over thousands of steps, this converges to a balanced equilibrium. The critical insight is that this bias affects only the routing decision, not the gating weights. The model's actual computation (the weighted sum of expert outputs) uses the original, unbiased logits, so the language modeling signal is never corrupted by a balancing objective.

DeepSeek-V3 reports that this approach achieves better model quality than auxiliary-loss-based methods while maintaining comparable load balance. The key advantage: there is zero tension between the language modeling loss and the balancing mechanism, because they operate in completely separate channels (gradients vs. heuristic bias updates).
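A toy simulation shows the integral-control behavior. The "router logits" here are synthetic draws with a built-in preference `mu` for some experts — an assumption purely for illustration, not DeepSeek-V3's actual update schedule:

```python
import numpy as np

rng = np.random.default_rng(0)
N, k, T, gamma = 8, 2, 1024, 0.02
mu = np.linspace(0.0, 3.0, N)  # innate router preference: last expert favored
b = np.zeros(N)                # per-expert bias, never touched by gradients

def load_fractions(b):
    """Fraction of top-k routing slots each expert wins in one synthetic batch."""
    g = mu + rng.normal(size=(T, N))          # raw router logits
    topk = np.argsort(g + b, axis=1)[:, -k:]  # selection uses g' = g + b
    return np.bincount(topk.ravel(), minlength=N) / (T * k)

before = load_fractions(b).std()
for _ in range(500):
    f = load_fractions(b)
    # overloaded experts (f above average) get bias -gamma; underloaded +gamma
    b += gamma * np.sign(f.mean() - f)
after = load_fractions(b).std()
print(before > after)  # load spread shrinks as the bias counteracts mu
```

The bias converges toward roughly $-\mu$ (up to a constant), equalizing the selection frequencies without any term ever appearing in a loss function.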

Soft MoE: No Hard Routing at All

Soft MoE (Puigcerver et al., 2023) eliminates discrete routing entirely. Instead of sending individual tokens to experts, it creates soft, weighted combinations of all tokens and sends those to experts.

The mechanism works in two steps:

1. Create soft token slots. A learned dispatch matrix $\mathbf{D} \in \mathbb{R}^{T \times S}$ computes soft combinations of all input tokens. Each of the $S$ "slots" is a weighted average of all tokens: $\tilde{\mathbf{h}}_s = \sum_t D_{ts} \cdot \mathbf{h}_t$, where $D_{ts} = \text{softmax}_t(\mathbf{h}_t \cdot \boldsymbol{\phi}_s)$ for learned slot parameters $\boldsymbol{\phi}_s$. Slots are evenly divided among experts.
2. Process slots through experts, then combine back. Each expert processes its assigned slots. The outputs are combined back into per-token outputs using a second soft assignment (the "combine" matrix $\mathbf{C}$), which is computed similarly: each token's output is a weighted sum of all slot outputs.

There is no top-$k$ selection, no discrete routing decision, and therefore no load balancing problem at all. Every expert always processes the same number of slots, and every token contributes to (and receives from) every expert through soft weights.

The trade-off is computational: Soft MoE requires attention-like operations (the dispatch and combine matrices involve all-to-all token interactions), which partially negates the efficiency advantage of sparse MoE. It also cannot easily scale to very large expert counts because the dispatch weights become a bottleneck.

Soft MoE has been primarily demonstrated in vision models (ViT-based). Its application to autoregressive language models is limited because the dispatch step requires access to all tokens simultaneously, which conflicts with causal generation.
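A minimal NumPy sketch of the dispatch/combine mechanism described above (toy sizes; `phi` and the expert maps are random stand-ins for learned parameters):

```python
import numpy as np

def soft_moe(H, phi, experts):
    """Soft MoE layer: dispatch tokens to slots, apply experts, combine back.

    H:       (T, d) token hidden states
    phi:     (S, d) learned slot parameters (slots split evenly over experts)
    experts: list of E callables; expert e gets slots [e*S/E, (e+1)*S/E)
    """
    logits = H @ phi.T                               # (T, S) token-slot scores
    D = np.exp(logits - logits.max(axis=0, keepdims=True))
    D /= D.sum(axis=0, keepdims=True)                # dispatch: softmax over tokens
    slots = D.T @ H                                  # (S, d) soft token mixtures
    per = len(slots) // len(experts)
    out = np.vstack([experts[e](slots[e * per:(e + 1) * per])
                     for e in range(len(experts))])  # (S, d) processed slots
    C = np.exp(logits - logits.max(axis=1, keepdims=True))
    C /= C.sum(axis=1, keepdims=True)                # combine: softmax over slots
    return C @ out                                   # (T, d) per-token outputs

rng = np.random.default_rng(0)
T, d, S, E = 10, 8, 4, 2
H = rng.normal(size=(T, d))
phi = rng.normal(size=(S, d))
experts = [lambda x, M=rng.normal(size=(d, d)) / np.sqrt(d): x @ M
           for _ in range(E)]

Y = soft_moe(H, phi, experts)
print(Y.shape)  # (10, 8)
```

Every expert processes exactly `S / E` slots on every batch, so balance holds by construction; the price is the two dense `(T, S)` softmax interactions, which is the attention-like cost mentioned above.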

Comparison Table

| Technique | How it works | Removes aux loss? | Autoregressive? | Used in |
|---|---|---|---|---|
| Load balancing loss | Auxiliary loss penalizing $f_i \cdot p_i$ | No (is the aux loss) | Yes | Switch, GShard, Mixtral |
| Capacity factor | Hard buffer limit per expert | No (used alongside) | Yes | Switch, GShard |
| Router z-loss | Penalizes large router logits | No (adds another aux) | Yes | ST-MoE, DeepSeek-V2/V3 |
| Shared experts | Always-on experts + routed experts | No (orthogonal) | Yes | DeepSeek-V2/V3 |
| Expert Choice | Experts select tokens (not vice versa) | Yes | No | EC (Zhou et al., 2022) |
| Fine-grained experts | Split experts into smaller segments | No (orthogonal) | Yes | DeepSeek-MoE/V2/V3 |
| Aux-loss-free bias | Dynamic bias on router logits | Yes | Yes | DeepSeek-V3 |
| Soft MoE | Soft dispatch, no discrete routing | Yes (no routing) | No | Soft MoE (Puigcerver, 2023) |

Practical Notes

What most production systems actually use

Modern MoE LLMs (Mixtral, DeepSeek-V3, Grok, DBRX) typically combine several of these techniques rather than relying on any single one. A common recipe:

- fine-grained routed experts plus a small number of shared experts,
- a balancing mechanism: either a small auxiliary loss ($\alpha \approx 0.01$) or DeepSeek-V3-style loss-free bias updates,
- router z-loss (or similar logit regularization) for training stability,
- a generous capacity factor, or dropless routing that avoids token dropping entirely.

How many experts?

There is no consensus. Mixtral uses 8 experts with top-2 (modest sparsity). DeepSeek-V3 uses 257 (1 shared + 256 routed) with top-8 out of 256 routed (high sparsity). More experts generally means more total parameters at the same compute cost, but the routing problem gets harder and communication overhead grows in distributed training.

Expert parallelism is the real bottleneck

The dominant practical concern is not routing quality but communication cost. In a distributed setup, each expert lives on a different device. Routing tokens to their assigned experts requires all-to-all communication across devices, which is expensive and latency-sensitive. This is why capacity factors and token dropping exist: they bound the worst-case communication load. DeepSeek-V3 addresses this with device-constrained routing, where a token can only be routed to experts on a limited set of devices, reducing the all-to-all communication to a smaller group.

MoE and fine-tuning

Fine-tuning MoE models requires care. If the fine-tuning data is small or narrowly distributed, the router may shift to routing all tokens to a few experts, and the unused experts degrade. Common mitigations: freeze the router during fine-tuning (only train expert FFNs), use a higher $\alpha$ for load balancing during fine-tuning, or use LoRA on expert weights to limit the parameter changes.

Evaluating expert utilization

Track these metrics during training to catch problems early:

- Per-expert token fractions ($f_i$): watch for experts drifting toward zero (dead) or hoarding a large share of tokens.
- Router entropy: collapsing entropy means the router is becoming overconfident in a few experts.
- Token drop rate (if using capacity factors): rising drops signal growing imbalance.
- Maximum router logit (or the z-loss value): growth here precedes the numerical instabilities described above.
