Unbiased Gradient Estimation with Balanced Assignments for Mixtures of Experts


Training large-scale mixture-of-experts models efficiently on modern hardware requires assigning datapoints to experts in a balanced way, to meet a fixed computation capacity per expert. Recently proposed heuristic or exact assignment procedures lack a probabilistic interpretation and use biased estimators for training, and are therefore not well understood. As an alternative, we propose two unbiased estimators based on principled stochastic assignment procedures: a simple one that skips datapoints which exceed expert capacity, and a more involved one that samples a perfectly balanced assignment using an extension of the Gumbel-Max trick. Experiments in a toy setting suggest that the `skip'-estimator is more effective than balanced sampling, while both outperform biased alternatives.
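As a minimal illustration of the simpler of the two procedures, the sketch below combines the standard Gumbel-Max trick (adding Gumbel noise to logits and taking the argmax yields an exact sample from the softmax distribution) with a capacity constraint: datapoints routed to an expert that is already full are skipped. The function names, the capacity convention, and the `-1` marker for skipped datapoints are illustrative choices, not the paper's implementation; the balanced-sampling extension of the Gumbel-Max trick is more involved and is not shown here.

```python
import numpy as np

rng = np.random.default_rng(0)

def gumbel_max_sample(logits):
    # Gumbel-Max trick: argmax(logits + Gumbel noise) is an exact
    # sample from Categorical(softmax(logits)).
    g = rng.gumbel(size=logits.shape)
    return int(np.argmax(logits + g))

def skip_assign(logits, capacity):
    # Illustrative 'skip' procedure: sample an expert per datapoint;
    # datapoints whose sampled expert is already at capacity are
    # skipped (marked -1) and excluded from this step's estimate.
    n, num_experts = logits.shape
    counts = np.zeros(num_experts, dtype=int)
    assignment = np.full(n, -1)
    for i in range(n):
        k = gumbel_max_sample(logits[i])
        if counts[k] < capacity:
            assignment[i] = k
            counts[k] += 1
    return assignment

logits = rng.normal(size=(8, 4))  # 8 datapoints, 4 experts
print(skip_assign(logits, capacity=2))
```

Because each datapoint's expert is drawn from its true routing distribution and overflow datapoints are dropped rather than rerouted, gradient contributions from the retained datapoints remain unbiased after appropriate reweighting.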