Designing a small Mixture-of-Experts that actually routes
This note documents the design of a small Mixture-of-Experts (MoE) language model I’ve been building as a personal, open-source research project. “Small” here means ~363M total parameters, ~192M active per token, ~1.15B FLOPs-equivalent per token across recursive loops — modest enough to pretrain on a single H100 for a few days and still produce useful signal about routing dynamics at the small-model end of the curve.
An earlier attempt at this architecture trained cleanly but never learned to route. The router sat at exactly uniform softmax for ten thousand pretrain steps, so the experts saw a perfectly mixed token diet and never specialised. This write-up is the design that replaces it, and the reasoning behind each non-obvious choice.
The failure mode worth naming
Routing collapse in small MoEs is rarely one bug — it’s a self-reinforcing loop between a few independently reasonable decisions. In the earlier attempt three things conspired:
- A dense feed-forward path running parallel to the experts. With dense_hidden=4096 next to expert_hidden=1024, the dense FFN had four times the hidden width of any single expert, absorbed most of the gradient, and left the MoE branch essentially optional. Gates on the first two blocks eventually went negative: the model was actively subtracting the expert contribution.
- A capacity cap tight enough to mask any routing decision. With capacity factor 1.25 and a small expert pool, every expert received an almost uniform mix of tokens regardless of what the router asked for. Without per-expert specialisation, the router had no usable signal to learn from, a self-reinforcing degeneracy.
- An “importance” auxiliary loss whose minimum is the uniform distribution. The router converged to exactly the thing the aux loss rewarded most: prob_entropy = ln(11) ≈ 2.40, a perfectly flat softmax over the 11 experts.
A 500-step recovery experiment — freezing everything except the router and experts, training with a diversity loss — produced no task-loss recovery. The collapse wasn’t fixable from that initialisation; the architecture needed to change.
Architecture summary
The new design removes the dense FFN entirely, makes the shared expert carry the “common knowledge” role, gives the routed experts real capacity, and replaces the importance loss with a Switch-style balance loss that penalises imbalance without requiring uniformity.
| Component | Earlier design | Current design | Why |
|---|---|---|---|
| Dense FFN | hidden=4096 (always on) | removed | Was doing the work the routed experts should have done. |
| Shared expert | hidden=1024 | hidden=4096 (SwiGLU) | Absorbs the dense FFN’s role as a stable, always-on path. |
| Routed experts | 11 × hidden=1024 | 8 × hidden=2048 | Fewer, larger experts — each has real representational room. |
| Routing | top-2, sigmoid gates | top-2, softmax + renormalised | Mixtral-style; gates on chosen experts sum to 1. |
| Aux loss | importance (min @ uniform) | Switch balance (E · Σ fᵢ · pᵢ) | Penalises imbalance without forcing uniformity. |
| Capacity factor | 1.25 (tight) | 2.0 (loose) | Router preferences actually decide routing in training. |
| Expert init | random | orthogonal, per-expert | Break symmetry structurally, not stochastically. |
Parameter budget (measured)
```text
Embedding (32K × 1536, tied with LM head)       :  50,331,648   always active
Attention × 3 blocks (qkv + out_proj)           :  28,311,552   always active
Shared expert × 3 (SwiGLU h=4096)               :  56,623,104   always active
Routed experts: 3 blocks × 8 experts × h=2048   : 226,492,416   physical
  → with top-2 of 8, active per token           :  56,623,104
Router, loop embeddings, adapters, norms        : ~1M
Total physical parameters                       : 362,720,259   (~363M)
Active per token                                : ~191,889,408  (~192M, 52.9%)
Effective compute per token (× 6 loops)         : ~1.15B FLOPs-equivalent
```
LM head weights are tied with the input embedding, saving roughly 50M parameters versus an untied head at this vocabulary size.
Three physical blocks are reused across six recursive loops, so the parameter count above hides an 18-effective-layer model. That recursion held up under the earlier attempt’s compute budget without instability and is kept as-is.
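As a cross-check, the budget can be re-derived from layer shapes alone. The sketch makes four assumptions: d_model = 1536, a 32,768-token vocabulary (the value that reproduces the embedding row exactly), a fused qkv projection, and bias-free linear layers. The roughly 1M of router, loop-embedding, adapter and norm parameters is taken as whatever remains of the measured total.

```python
# Re-derive the parameter budget from shapes (assumed: d_model=1536,
# vocab=32,768, fused qkv, bias-free projections, SwiGLU = 3 matrices).
D_MODEL, VOCAB, BLOCKS = 1536, 32_768, 3
SHARED_H, EXPERT_H, N_EXPERTS, TOP_K, LOOPS = 4096, 2048, 8, 2, 6
MEASURED_TOTAL = 362_720_259


def swiglu_params(d: int, h: int) -> int:
    # w_gate and w_up (d -> h) plus w_down (h -> d), no biases
    return 3 * d * h


embedding = VOCAB * D_MODEL                                   # tied with LM head
attention = BLOCKS * (D_MODEL * 3 * D_MODEL + D_MODEL * D_MODEL)  # qkv + out_proj
shared = BLOCKS * swiglu_params(D_MODEL, SHARED_H)
routed_physical = BLOCKS * N_EXPERTS * swiglu_params(D_MODEL, EXPERT_H)
routed_active = BLOCKS * TOP_K * swiglu_params(D_MODEL, EXPERT_H)

active = embedding + attention + shared + routed_active
residual_small = MEASURED_TOTAL - (embedding + attention + shared + routed_physical)

for label, value in [
    ("embedding", embedding),
    ("attention", attention),
    ("shared experts", shared),
    ("routed (physical)", routed_physical),
    ("routed (active)", routed_active),
    ("active per token", active),
    (f"active x {LOOPS} loops", active * LOOPS),
    ("router/norms/etc.", residual_small),
]:
    print(f"{label:20s} {value:>14,}")
print(f"active fraction      {active / MEASURED_TOTAL:.3f}")
```

Multiplying the active count by six reproduces the ~1.15B per-token figure in the listing.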
Block forward
```python
def forward(self, x, loop_idx):
    x = x + self.attn(self.ln1(x), loop_idx)
    # MoELayer returns (shared_out, routed_out) so the Block can gate them
    # independently. The shared expert is always at full weight; only the
    # routed branch passes through moe_gate.
    shared_out, routed_out = self.moe(self.norm_moe(x), loop_idx)
    x = x + shared_out + self.moe_gate * routed_out
    return x
```
moe_gate initialises at 1.0 so there is no dense crutch to hide behind
— the routed branch is on the critical path from step 0. The scalar is
trainable and can drift downwards if the routed experts turn out to be
net-harmful (a useful diagnostic), but the shared expert contribution is
never gated.
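To experiment with the residual wiring, here is a toy, runnable stand-in. It is a sketch under stated assumptions, not the project's MoELayer:

- SwiGLU, ToyMoE and ToyBlock are hypothetical names.
- Attention is omitted.
- LayerNorm stands in for whatever norm_moe actually is.
- Routed experts run in a plain per-expert loop with no capacity limit.

It shows the two properties the paragraph relies on. First, the shared expert enters the residual unscaled. Second, only the routed branch is multiplied by the trainable moe_gate, which starts at 1.0.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLU(nn.Module):
    def __init__(self, d: int, h: int):
        super().__init__()
        self.w_gate = nn.Linear(d, h, bias=False)
        self.w_up = nn.Linear(d, h, bias=False)
        self.w_down = nn.Linear(h, d, bias=False)

    def forward(self, x):
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))


class ToyMoE(nn.Module):
    """Shared expert plus top-k routed experts, no capacity limit (hypothetical)."""

    def __init__(self, d, shared_h, expert_h, n_experts=8, k=2, n_loops=6):
        super().__init__()
        self.shared = SwiGLU(d, shared_h)
        self.experts = nn.ModuleList(SwiGLU(d, expert_h) for _ in range(n_experts))
        self.router = nn.Linear(d, n_experts, bias=False)
        self.loop_embedding = nn.Parameter(torch.zeros(n_loops, d))
        self.k = k

    def forward(self, x, loop_idx):  # x: (N, d)
        probs = F.softmax(self.router(x + self.loop_embedding[loop_idx]), dim=-1)
        top_p, top_idx = probs.topk(self.k, dim=-1)        # (N, k)
        gates = top_p / top_p.sum(-1, keepdim=True)        # renormalised, sum to 1
        routed = torch.zeros_like(x)
        for e, expert in enumerate(self.experts):
            hit = top_idx == e                             # (N, k) bool
            rows = hit.any(-1).nonzero(as_tuple=True)[0]   # tokens that picked e
            if rows.numel() == 0:
                continue
            weight = (gates * hit).sum(-1)[rows].unsqueeze(-1)
            routed = routed.index_add(0, rows, weight * expert(x[rows]))
        return self.shared(x), routed


class ToyBlock(nn.Module):
    def __init__(self, d, moe):
        super().__init__()
        self.norm_moe = nn.LayerNorm(d)
        self.moe = moe
        self.moe_gate = nn.Parameter(torch.tensor(1.0))    # routed branch fully on at step 0

    def forward(self, x, loop_idx):
        # Attention sub-layer omitted; only the MoE residual path is shown.
        shared_out, routed_out = self.moe(self.norm_moe(x), loop_idx)
        return x + shared_out + self.moe_gate * routed_out


torch.manual_seed(0)
block = ToyBlock(16, ToyMoE(16, shared_h=64, expert_h=32))
x = torch.randn(10, 16)
y = block(x, loop_idx=3)
y.pow(2).mean().backward()
```

A backward pass through this toy shows that the router weights receive gradient through the renormalised gates, not only through an auxiliary loss.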
Routing
Top-k softmax with renormalised gates is the Mixtral pattern: each picked expert carries a real weight in the output, so the router receives direct task gradient through the gate on the chosen path.
```python
import math
import torch.nn.functional as F

# Inside the MoE layer: x is (N, d_model), router maps d_model -> E logits,
# loop_embedding is an (n_loops, d_model) table, capacity_factor = 2.0.
E, K = 8, 2
N = x.shape[0]

router_logits = router(x + loop_embedding[loop_idx])   # (N, E)
probs = F.softmax(router_logits, dim=-1)               # (N, E)
raw_top, top_idx = probs.topk(K, dim=-1)               # (N, K)
top_probs = raw_top / raw_top.sum(-1, keepdim=True)    # gates sum to 1

# Switch balance loss, generalised to top-k
one_hot = F.one_hot(top_idx, E).to(probs.dtype)        # (N, K, E)
f = one_hot.sum((0, 1)) / (N * K)                      # sum(f) = 1
p = probs.mean(0)                                      # sum(p) = 1
balance_loss = E * (f * p).sum()                       # minimum = 1.0

# Capacity scales with K because each token picks K experts
capacity = math.ceil(capacity_factor * K * N / E)      # 2.0 * 2 * N / 8 = N/2
```
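The Switch balance loss is $\mathcal{L}_{\text{balance}} = E \sum_{i=1}^{E} f_i \, p_i$. Here $f_i$ is the fraction of the batch's $N \cdot K$ top-k assignments that go to expert $i$, and $p_i$ is the batch-mean router probability for expert $i$. The minimum is 1.0 when both f and p are uniform. Unlike the importance loss, it does not force uniformity; it just makes large imbalances expensive. The model is free to pick a sharp per-token distribution as long as the global load over a batch stays roughly balanced.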
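The difference from a uniformity-seeking loss shows up when the Switch term is evaluated on three synthetic routers over E = 8 experts:

- a perfectly flat softmax;
- a sharp router whose two peaks rotate across experts, so the batch load is balanced;
- a sharp router that always picks experts 0 and 1.

The helper names and the 0.45 peak mass are illustrative assumptions.

```python
import torch
import torch.nn.functional as F


def switch_balance_loss(probs: torch.Tensor, k: int = 2) -> torch.Tensor:
    """E * sum_i f_i * p_i for a top-k router; probs has shape (N, E)."""
    n, e = probs.shape
    top_idx = probs.topk(k, dim=-1).indices
    f = F.one_hot(top_idx, e).to(probs.dtype).sum((0, 1)) / (n * k)  # load fraction
    p = probs.mean(0)                                                 # mean router prob
    return e * (f * p).sum()


def mean_token_entropy(probs: torch.Tensor) -> torch.Tensor:
    return -(probs * probs.clamp_min(1e-12).log()).sum(-1).mean()


def two_peak_router(n, e, first, second, peak=0.45):
    """Each token puts `peak` mass on two experts and spreads the rest evenly."""
    probs = torch.full((n, e), (1.0 - 2 * peak) / (e - 2))
    rows = torch.arange(n)
    probs[rows, first] = peak
    probs[rows, second] = peak
    return probs


E, N = 8, 64
t = torch.arange(N)
routers = {
    "flat": torch.full((N, E), 1.0 / E),
    "sharp, balanced": two_peak_router(N, E, t % E, (t + 1) % E),
    "sharp, collapsed": two_peak_router(N, E, torch.zeros_like(t), torch.ones_like(t)),
}
for name, probs in routers.items():
    print(f"{name:17s} balance={switch_balance_loss(probs).item():.3f} "
          f"token_entropy={mean_token_entropy(probs).item():.3f}")
```

The flat and the sharp-but-balanced routers both land on 1.0, while the collapsed one pays 3.6 (E × 0.45 here). The loss leaves per-token sharpness free and charges only for global imbalance.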
The aux-loss coefficient lives at 0.03. Sanity runs showed Switch’s
canonical 0.01 was too weak for an 8-expert pool at this scale: by step
190 the router was sharp per token (per_token_entropy ≈ 0.66,
raw_max_prob ≈ 0.65) but globally collapsed to two active experts
(marginal_entropy ≈ 0.71, drop_rate = 0.44). Raising to 0.03 brought
the marginal entropy back to ~1.08 and the drop rate down to ~0.28
without hurting task loss. Pushing further to 0.10 gave diminishing
returns. The aux loss is added to the total only when model.training is
True; evaluation loss is pure task cross-entropy so perplexity isn’t
polluted by a regulariser choice.
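The training-only wiring of the aux term might look like the following minimal sketch. The note does not say how the per-call balance losses (three blocks × six loops) are combined, so averaging them is an assumption. The function name is also hypothetical.

```python
def total_loss(task_loss, balance_losses, training, aux_coef=0.03):
    """Task cross-entropy plus the Switch aux term, added in training only.

    balance_losses: one Switch balance loss per MoE call (blocks x loops).
    Averaging them is an assumption; works with floats or 0-d tensors.
    """
    if not training:
        return task_loss  # eval: pure cross-entropy, perplexity stays clean
    return task_loss + aux_coef * sum(balance_losses) / len(balance_losses)
```

In a training loop this would be called as total_loss(ce, aux_losses, model.training), so switching to model.eval() drops the regulariser automatically.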
Early exploration via Gumbel noise
Aux loss helps once experts are alive but doesn’t prevent the first few
hundred steps killing a couple of them outright. The schedule that worked
in extended sanity runs: start at router_gumbel_tau = 0.5 and anneal
linearly to 0.0 over the first 4,000 steps, then leave 1,000
no-noise steps before the 5k health gate. An earlier 1,000-step anneal
expired while the 3,000-step LR warmup was still ramping
task-gradient pressure, and clean-router marginal entropy fell from 1.79
at step 800 to 0.75 at step 1,000. The longer schedule keeps
exploration through peak LR.
Telemetry logs both the noisy dispatch and a parallel “clean” router
forward pass with tau = 0. Every health metric below is computed on the
clean router so noise-assisted balance can’t dress up a router that
hasn’t actually learned.
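Here is a minimal sketch of the schedule and the noise injection. It assumes router_gumbel_tau scales additive Gumbel(0, 1) noise on the router logits, so tau = 0 is exactly the clean router used for telemetry. The note does not specify whether the noise feeds the softmax gates or only expert selection. The function names are hypothetical.

```python
import torch


def gumbel_tau(step: int, tau0: float = 0.5, anneal_steps: int = 4000) -> float:
    """Linear anneal from tau0 at step 0 to 0.0 at anneal_steps, then held at 0."""
    return tau0 * max(0.0, 1.0 - step / anneal_steps)


def noisy_router_logits(logits: torch.Tensor, tau: float) -> torch.Tensor:
    """Add tau-scaled Gumbel(0, 1) noise to router logits; tau = 0 is the clean router."""
    if tau <= 0.0:
        return logits
    u = torch.rand_like(logits).clamp_min(1e-9)   # rand is in [0, 1); avoid log(0)
    gumbel = -torch.log(-torch.log(u))
    return logits + tau * gumbel
```

Telemetry would then run the same router twice per step. The dispatch pass uses noisy_router_logits(logits, gumbel_tau(step)); the health-metric pass uses tau = 0.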
Capacity, drops, and evaluation
Capacity factor of 2.0 means each expert can hold up to 2 · K · N / E
tokens per batch (N/2 here). Tokens beyond capacity for a given expert
are dropped from that expert’s branch in training only. In eval mode the
capacity is set to N · K so drops never occur — generation is
chunk-stable by construction, which matters for autoregressive decoding
where a token being dropped mid-sequence would change the model’s output
distribution per batch size.
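The capacity rule can be sketched as a keep-mask over the (N, K) top-k assignments. This assumes slots are filled in token order, with no priority for top-1 picks or higher gate values. The note doesn't say which priority rule is used, and the function name is hypothetical.

```python
import math

import torch
import torch.nn.functional as F


def capacity_keep_mask(top_idx: torch.Tensor, num_experts: int,
                       capacity_factor: float = 2.0, training: bool = True):
    """Return a (N, K) bool mask of assignments that fit in their expert's buffer.

    Slots are filled in token order (token 0's K picks, then token 1's, ...);
    this priority rule is an assumption.
    """
    n, k = top_idx.shape
    if training:
        capacity = math.ceil(capacity_factor * k * n / num_experts)
    else:
        capacity = n * k  # eval: every assignment fits, so nothing is dropped
    one_hot = F.one_hot(top_idx.reshape(-1), num_experts)  # (N*K, E), int64
    slot = ((one_hot.cumsum(0) - 1) * one_hot).sum(-1)    # 0-based slot in chosen expert
    return (slot < capacity).reshape(n, k), capacity


# Worst case: all 8 tokens pick experts 0 and 1 out of E = 8.
top_idx = torch.tensor([[0, 1]] * 8)
keep_train, cap_train = capacity_keep_mask(top_idx, 8, training=True)
keep_eval, cap_eval = capacity_keep_mask(top_idx, 8, training=False)
drop_rate = 1.0 - keep_train.float().mean().item()
```

In this worst case, training keeps half of the assignments (capacity 4 per expert, 8 requests each), while eval mode keeps all of them.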
Expert initialisation
Random per-expert weights aren’t actually different enough to avoid gradient symmetry — small QR-orthogonal blocks give the optimiser something structural to break.
```python
# For each expert's w_gate, w_up, w_down:
#   1. Sample a random matrix per expert (different seed each)
#   2. QR-decompose to get orthogonal columns
#   3. Rescale to match the standard fan-in init variance
#   4. Apply after HF post_init, so the framework's re-init can't clobber it
```
Each expert starts in a different feature subspace. The router has signal to discriminate from step 1, rather than waiting for stochastic gradient drift to break ties.
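The four commented steps can be made concrete as follows. This sketch assumes:

- bias-free (out, in) weight matrices;
- a fan-in target of std = 1/sqrt(in);
- seeds of the form base + expert index.

The function name and seed scheme are hypothetical. Per step 4, it would be applied to each expert's w_gate, w_up and w_down after post_init has run.

```python
import torch


def orthogonal_expert_init_(weight: torch.Tensor, seed: int) -> torch.Tensor:
    """QR-orthogonal init for one (out, in) expert matrix, rescaled to fan-in std."""
    out_f, in_f = weight.shape
    gen = torch.Generator().manual_seed(seed)             # distinct seed per expert
    a = torch.randn(max(out_f, in_f), min(out_f, in_f), generator=gen)
    q, r = torch.linalg.qr(a)                             # q has orthonormal columns
    q = q * torch.sign(torch.diagonal(r))                 # fix QR's sign ambiguity
    if out_f < in_f:
        q = q.T                                           # wide matrix: orthonormal rows
    target_std = in_f ** -0.5                             # standard fan-in scale
    q = q * (target_std / q.pow(2).mean().sqrt())         # match entry RMS to target
    with torch.no_grad():
        weight.copy_(q)
    return weight


# Hypothetical sizes: 4 experts, each with a tall (32, 16) w_up, plus one wide w_down.
w_up = [orthogonal_expert_init_(torch.empty(32, 16), seed=1000 + i) for i in range(4)]
w_down = orthogonal_expert_init_(torch.empty(16, 32), seed=2000)
```

For a tall matrix the columns come out mutually orthogonal; for a wide one the rows do. In both cases the entry RMS matches the fan-in target, so the init changes the matrix's structure, not its scale.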
Optimiser × routing ablation
A 2×2 ablation at 1,200 foundation-matched steps drove the production optimiser choice:
| Cell | Optimiser | Aux | Final task loss | clean entropy min | balance | Verdict |
|---|---|---|---|---|---|---|
| A | Lion | 0.10 | ~7.4 | ~0.002 | ~1.2 | Baseline; partial late-warmup collapse. |
| B | Lion | 0.00 | ~7.6 | 0.000 | ~3.9 | Bias-only controller collapses by step 500. |
| C | Muon | 0.10 | 3.43 | 0.105 | 1.02 | Production recipe. |
| D | Muon | 0.00 | ~4.7 | ~0.001 | ~2.3 | Gets most of C’s task-loss gain, with a degenerate router. |
Three load-bearing conclusions:
- The Muon hybrid optimiser is a 3 to 4-nat task-loss win at this scale, regardless of routing scheme (A → C ≈ 4.0, B → D ≈ 2.9). C and D both crossed task < 5.0 by step 450, while A and B were still around 7.2 there. Newton–Schulz orthogonalisation of the momentum update equalises step sizes across singular directions. That is exactly what helps cold experts get useful gradient instead of being shrunk into oblivion by per-coordinate variance scaling.
- Gradient aux loss is necessary for router health, regardless of optimiser. A bias controller alone collapses the raw router under both Lion and Muon. Auxiliary-loss-free routing has recently been reported as stable at the very large-model end; at 363M on this curriculum it isn’t.
- C is the recipe: best loss, best balance, best entropy floor, best margin. D gets most of C’s task-loss gain while letting the raw router degenerate to two or three active experts. That is fine at early-curriculum difficulty, but later phases need the unused capacity D throws away.
Lion stays available as a fallback (optimizer_name="lion") for A/B
comparisons; the Muon hybrid is the default. Muon handles the 2D matrix
weights (attention projections, expert weights); AdamW continues to manage
embeddings, norms, the router, and scalar parameters where Muon’s
orthogonalisation step is either undefined or unhelpful.
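The Newton–Schulz step credited in the ablation can be sketched as below, together with a name-based split of parameters between the two optimisers. The quintic coefficients and five iterations follow the public reference Muon implementation, not necessarily this project's settings. Momentum, learning-rate scaling and the AdamW side are omitted. The name filters in split_param_groups are hypothetical.

```python
import torch


def newton_schulz_orthogonalize(g: torch.Tensor, steps: int = 5, eps: float = 1e-7):
    """Approximately map a 2-D update G = U S V^T toward U V^T.

    Singular values are pushed into a band around 1 rather than exactly to 1.
    Coefficients follow the public reference Muon implementation.
    """
    a, b, c = 3.4445, -4.7750, 2.0315
    x = g / (g.norm() + eps)          # Frobenius norm <= 1, so every singular value <= 1
    transposed = x.shape[0] > x.shape[1]
    if transposed:
        x = x.T                       # iterate on the smaller Gram matrix
    for _ in range(steps):
        gram = x @ x.T
        x = a * x + (b * gram + c * gram @ gram) @ x
    return x.T if transposed else x


def split_param_groups(named_params):
    """Send 2-D hidden weights to Muon; embeddings, router, norms, scalars to AdamW.

    The name-based exclusions ("embed", "router") are hypothetical.
    """
    muon, adamw = [], []
    for name, p in named_params:
        if p.ndim == 2 and not any(tag in name for tag in ("embed", "router")):
            muon.append(p)
        else:
            adamw.append(p)
    return muon, adamw
```

The iteration pushes every singular value of the normalised update into a band around 1. That is the "equal step sizes across singular directions" property the ablation conclusions rely on.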
Training plan
Three phases on the same recursive backbone, with each phase widening the context window and the data mix:
- Foundation — 10k steps, seq_len 2048, B=8, accum=8. FineWeb-Edu (60%) + CodeParrot (40%). Peak LR 6e-4 on the AdamW side, 0.02 on the Muon side, both cosine. Checkpoint every 1,000 steps.
- Knowledge — 30–60k steps, seq_len 4096, B=4, accum=16. FineWeb-Edu (50%) + CodeParrot (30%) + Wikipedia (20%). Cosine continues from the previous phase.
- Instruction — 10–20k steps, seq_len 8192. SmolTalk + Evol-Code + OpenHermes, with HRA adapters on top so the pretrained weights stay clean.
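For scale, the foundation and knowledge settings translate into tokens per optimiser step as shown below. This assumes B is the micro-batch on a single GPU and accum is the number of gradient-accumulation micro-steps. The instruction phase is left out because its batch settings aren't given.

```python
# Tokens per optimiser step = seq_len * micro-batch * gradient-accumulation steps.
PHASES = {
    # name: (seq_len, micro_batch, accum, (min_steps, max_steps))
    "foundation": (2048, 8, 8, (10_000, 10_000)),
    "knowledge": (4096, 4, 16, (30_000, 60_000)),
}


def tokens_per_step(seq_len: int, micro_batch: int, accum: int) -> int:
    return seq_len * micro_batch * accum


for name, (seq_len, b, accum, (lo, hi)) in PHASES.items():
    per_step = tokens_per_step(seq_len, b, accum)
    print(f"{name:10s} {per_step:>9,} tokens/step  "
          f"{per_step * lo / 1e9:.2f}-{per_step * hi / 1e9:.2f}B tokens")
```

That is 131,072 tokens per step (about 1.3B tokens) for foundation and 262,144 per step (about 7.9B to 15.7B tokens) for knowledge.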
Health gate at step 5k
The earlier attempt’s failure mode would have evaded any single-metric check, because batch-marginal router entropy stayed high while per-token entropy also stayed high — the router never committed. Phase 1 trips a hard stop unless all four hold on the clean router by step 5,000:
- clean_per_token_entropy < 1.5 (down from ~ln 8 ≈ 2.08): the router differentiates per token, not just per batch.
- clean_raw_max_prob > 0.30 (up from ~1/8): the router has a clear primary pick for most tokens.
- clean_top_margin > 0.10: a meaningful gap between top-1 and top-2.
- clean_marginal_entropy > 1.8: global balance preserved (no dead experts).
If those four don’t move together by step 5k the issue is structural — likely recursive weight sharing compounding across loops, or a capacity vs. task-diversity mismatch — and the run stops.
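The four-way gate can be sketched as a check over clean-router probabilities. The exact metric definitions are assumptions: per-token quantities are averaged over tokens, and marginal entropy is taken as the entropy of the batch-mean distribution. The three synthetic routers mirror three cases:

- the earlier attempt (flat);
- the target (sharp and balanced);
- warmup collapse (sharp but concentrated on two experts).

```python
import torch


def router_health(probs: torch.Tensor) -> dict:
    """Clean-router (tau = 0) metrics from softmax probs of shape (N, E).

    Definitions are assumptions: per-token values are averaged over tokens,
    marginal entropy is the entropy of the batch-mean distribution.
    """
    def entropy(q):
        return -(q * q.clamp_min(1e-12).log()).sum(-1)

    top2 = probs.topk(2, dim=-1).values
    return {
        "clean_per_token_entropy": entropy(probs).mean().item(),
        "clean_raw_max_prob": top2[:, 0].mean().item(),
        "clean_top_margin": (top2[:, 0] - top2[:, 1]).mean().item(),
        "clean_marginal_entropy": entropy(probs.mean(0)).item(),
    }


def passes_health_gate(m: dict) -> bool:
    return (m["clean_per_token_entropy"] < 1.5
            and m["clean_raw_max_prob"] > 0.30
            and m["clean_top_margin"] > 0.10
            and m["clean_marginal_entropy"] > 1.8)


def make_router(primary, secondary, e=8, p1=0.6, p2=0.2):
    """Synthetic router: p1 on a primary expert, p2 on a secondary, rest spread evenly."""
    n = primary.shape[0]
    probs = torch.full((n, e), (1.0 - p1 - p2) / (e - 2))
    rows = torch.arange(n)
    probs[rows, primary] = p1
    probs[rows, secondary] = p2
    return probs


t = torch.arange(64)
flat = torch.full((64, 8), 1.0 / 8)                                # never commits
committed = make_router(t % 8, (t + 1) % 8)                        # sharp and balanced
collapsed = make_router(torch.zeros_like(t), torch.ones_like(t))   # sharp, two experts
```

Only the committed router passes. The flat one fails on per-token entropy and the collapsed one on marginal entropy, which is why no single metric is enough.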
What this design bets on
- Top-k softmax with renormalised gates forces commitment: the chosen experts carry real weight in the output, and the router receives direct task gradient through the gate.
- No dense FFN — the routed branch is on the critical path for every token. The model cannot quietly ignore the experts.
- Switch balance loss penalises imbalance without rewarding uniformity. The router can be sharp per token as long as the batch-wide load stays balanced.
- Larger experts (hidden=2048) give each one room to represent useful structure, not a rank-16 perturbation on a shared backbone.
- Orthogonal per-expert init starts experts in genuinely different subspaces, so gradients push them apart organically.
- Loose capacity (2.0) in training, no drops in eval — router preferences decide routing during learning, and generation is chunk-stable at inference time.
- Annealed Gumbel exploration through LR warmup keeps experts alive while task gradient ramps, then yields to the clean router with 1,000 steps to spare before the health gate.
- Per-loop, per-expert balance-bias controller counter-rotates clean top-k load imbalance once per optimiser step, on top of the Switch aux loss. The bias is persistent and part of the deployed routing path, loop-specific because capacity is enforced per MoE call.
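A per-loop, per-expert bias controller could look like the sketch below. The update rule is swapped in from the sign-based scheme used for auxiliary-loss-free balancing in DeepSeek-V3; it is not taken from this note. Three further details are assumptions: the step size, the class name, and applying the bias to expert selection only (gate values still come from the unbiased softmax).

```python
import torch


class LoopBalanceBias:
    """Per-loop, per-expert routing bias nudged against top-k load imbalance (hypothetical)."""

    def __init__(self, n_loops: int, n_experts: int, rate: float = 1e-3):
        self.bias = torch.zeros(n_loops, n_experts)   # persistent; kept at inference
        self.rate = rate

    def select(self, clean_logits: torch.Tensor, loop_idx: int, k: int = 2):
        # Assumption: the bias shifts which experts are picked, not the gate values.
        return (clean_logits + self.bias[loop_idx]).topk(k, dim=-1).indices

    @torch.no_grad()
    def update(self, loop_idx: int, top_idx: torch.Tensor) -> None:
        """Call once per optimiser step with that loop's clean top-k choices."""
        n_experts = self.bias.shape[1]
        load = torch.bincount(top_idx.reshape(-1), minlength=n_experts).float()
        # Overloaded experts are pushed down, underloaded ones pulled up.
        self.bias[loop_idx] += self.rate * torch.sign(load.mean() - load)


ctrl = LoopBalanceBias(n_loops=6, n_experts=8)
logits = torch.zeros(16, 8)
logits[:, :2] = 1.0                  # every token prefers experts 0 and 1
picked = ctrl.select(logits, loop_idx=0)
ctrl.update(0, picked)
```

After one update on this batch, the biases of experts 0 and 1 drop by the step size and the other six rise. Only loop 0's row changes.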
What I’m watching
The honest list of things that could still go wrong, in rough order of likelihood:
- Early collapse to two or three active experts during the LR warmup. Sanity runs at coefficients 0.03, 0.10 (noisy) and 0.10 (clean) all eventually collapsed under warmup pressure. The current fix is the combination of the Switch aux loss and the per-loop bias controller, but this is the failure mode I expect the model to keep testing.
- Expert count vs. task diversity. Eight experts feels right at this model size, but if specialisation doesn’t emerge after the foundation phase, the fallback is four experts at hidden=4096.
- Shared expert dominating. If the shared path does 95% of the work and routed experts stay weak, the shared expert hidden dimension comes down to 2048.
- Cross-loop routing consistency. The same router is reused across six recursive loops, with only loop_embedding[loop_idx] to distinguish them. Telemetry will show whether the router learns genuinely loop-specific preferences or just averages them.
If the model gets through Phase 1’s health gate, the most interesting question stops being “did the router learn?” and becomes “what does each expert specialise on?” — which is the question this whole project exists to answer.