[P] Fused MoE Dispatch in Pure Triton: Beating CUDA-Optimized Megablocks at Inference Batch Sizes
I built a fused MoE dispatch kernel in pure Triton that handles the full forward pass for Mixture-of-Experts models. No CUDA, no vendor-specific code.
On Mixtral-8x7B (A100), it beats Stanford’s Megablocks at inference-relevant batch sizes: 131% of Megablocks’ throughput at 32 tokens, 124% at 128 tokens. At larger batches Megablocks’ hand-tuned CUDA pulls ahead, as expected.
Two main contributions:
- Fused gate+up projection – both GEMMs share the same input-tile load, and SiLU is computed in registers. This eliminates ~470 MB of intermediate buffers per forward pass (a 35% reduction in memory traffic).
- Block-scheduled grouped GEMM – a precomputed block_id → (expert_id, offset) mapping handles variable-sized expert batches in a single kernel launch, with no padding.
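For readers who want the math behind the first bullet: the fused kernel computes silu(x @ W_gate) * (x @ W_up) without materializing either GEMM output in global memory. A minimal NumPy reference for just the arithmetic (function and weight names are mine, not from the repo; the real kernel does this per-tile in Triton registers):

```python
import numpy as np

def fused_gate_up_reference(x, w_gate, w_up):
    """Reference math for a fused gate+up projection.

    The Triton kernel loads each input tile once, feeds it to both
    GEMMs, and applies SiLU in registers; this NumPy version only
    checks the arithmetic: silu(x @ w_gate) * (x @ w_up).
    """
    gate = x @ w_gate
    up = x @ w_up
    silu = gate * (1.0 / (1.0 + np.exp(-gate)))  # SiLU(z) = z * sigmoid(z)
    return silu * up
```

The memory saving comes from never writing `gate` and `up` back to HBM between the projections and the activation.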
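The block-scheduling idea in the second bullet can be sketched in a few lines of host-side Python (a hypothetical helper, not the repo's actual code): tile each expert's token batch into fixed-size row blocks, then launch one flat grid where each program looks up its (expert_id, offset) pair instead of padding every expert to the max batch.

```python
def build_block_schedule(tokens_per_expert, block_m):
    """Precompute a flat block_id -> (expert_id, row_offset) table.

    Each expert's batch of n tokens becomes ceil(n / block_m) row
    blocks; experts with zero tokens contribute no blocks, so a single
    1-D kernel launch covers all experts with no padding.
    """
    schedule = []
    for expert_id, n in enumerate(tokens_per_expert):
        for offset in range(0, n, block_m):
            schedule.append((expert_id, offset))
    return schedule
```

Inside the grouped-GEMM kernel, `tl.program_id(0)` would index this table to find which expert slice the current block owns.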
Tested across Mixtral-8x7B, DeepSeek-V3 (256 experts), and Qwen2-MoE. The full test suite also passes on AMD MI300X with zero code changes.
Code: https://github.com/bassrehab/triton-kernels
Writeup: https://subhadipmitra.com/blog/2026/fused-moe-dispatch-triton/
submitted by /u/bassrehab