Unbiased Gradient Estimation with Balanced Assignments for Mixtures of Experts

Abstract

Training large-scale mixture-of-experts models efficiently on modern hardware requires balanced assignment of datapoints to experts, to meet a fixed computational capacity per expert. Recently proposed heuristic or exact assignment procedures lack a probabilistic interpretation and use biased estimators for training, and are therefore not well understood. As an alternative, we propose two unbiased estimators based on principled stochastic assignment procedures: a simple one that skips datapoints which exceed expert capacity, and a more involved one that samples a perfectly balanced assignment using an extension of the Gumbel-Max trick. Experiments in a toy setting suggest that the 'skip'-estimator is more effective than balanced sampling, while both outperform biased alternatives.
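To make the two assignment procedures concrete, here is a minimal, self-contained sketch assuming a matrix of datapoint-to-expert logits. The function names (`skip_assignment`, `balanced_assignment`) are hypothetical, and the balanced variant solves the Gumbel-perturbed matching with SciPy's off-the-shelf Hungarian solver purely for illustration; the paper's actual sampling procedure may differ.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def skip_assignment(logits, capacity, rng):
    """'Skip' estimator's assignment step (illustrative sketch):
    sample an expert per datapoint via the Gumbel-Max trick and
    drop datapoints once their sampled expert is at capacity."""
    n, e = logits.shape
    # argmax(logits + Gumbel noise) samples from softmax(logits)
    choices = np.argmax(logits + rng.gumbel(size=(n, e)), axis=1)
    counts = np.zeros(e, dtype=int)
    assignment = np.full(n, -1)  # -1 marks a skipped datapoint
    for i, k in enumerate(choices):
        if counts[k] < capacity:
            assignment[i] = k
            counts[k] += 1
    return assignment


def balanced_assignment(logits, rng):
    """Balanced assignment in the spirit of the Gumbel-Max
    extension: perturb the logits with Gumbel noise, then solve a
    matching so each expert receives exactly n / e datapoints.
    Uses the Hungarian algorithm here; only a sketch."""
    n, e = logits.shape
    assert n % e == 0, "requires n divisible by the number of experts"
    perturbed = logits + rng.gumbel(size=(n, e))
    # Replicate each expert column (capacity) times -> square problem.
    cost = -np.repeat(perturbed, n // e, axis=1)
    rows, cols = linear_sum_assignment(cost)
    # Map each capacity slot back to its expert index.
    return cols[np.argsort(rows)] // (n // e)


rng = np.random.default_rng(0)
logits = rng.normal(size=(8, 4))  # 8 datapoints, 4 experts
print(skip_assignment(logits, capacity=2, rng=rng))
print(balanced_assignment(logits, rng=rng))
```

The skip variant keeps the per-datapoint categorical distribution intact and simply discards overflow, while the balanced variant trades extra assignment cost for a guarantee that every expert is filled exactly to capacity.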
