Key Points
- Branch-Train-MiX (BTX) is a method for training Large Language Models (LLMs) to acquire capabilities in multiple specialized domains, such as coding, math reasoning, and world knowledge.
- BTX starts from a seed model, which is branched to train experts in an embarrassingly parallel fashion with high throughput and reduced communication cost; it later brings together their feedforward parameters as experts in Mixture-of-Experts (MoE) layers and averages the remaining parameters.
- BTX achieves the best accuracy-efficiency tradeoff among the compared approaches and generalizes two special cases: the Branch-Train-Merge method and sparse upcycling.
- Recent work by Li et al. proposed the Branch-Train-Merge (BTM) method, which trains LLMs in an embarrassingly parallel fashion without any synchronization to improve pretraining throughput. However, BTM produces multiple independent LLMs that share no parameters, so there is no unified single model for further supervised finetuning or reinforcement learning from human feedback.
- The paper reviews the Mixture-of-Experts (MoE) approach, which scales deep networks while keeping only a subset of parameters active at any given time.
- Compared to sparse upcycling and Branch-Train-Merge, BTX is more compute efficient, with higher training throughput and more balanced performance across tasks in different domains.
- BTX models demonstrate improved performance over the seed model on tasks across various domains, especially bridging the gap with specialized models on math and code-related tasks while retaining performance on the original capabilities where specialized models suffer from catastrophic forgetting.
- BTX methods outperform BTM on all tasks, demonstrating the benefits of learned routing through MoE finetuning.
- Comparing BTX with several routing variants indicates that routing with load balancing improves performance on coding tasks but degrades performance on math tasks.
- Comparing routing decisions across router designs and downstream tasks shows slight variation in routing distributions in the initial layers, but the differences quickly become indistinguishable in subsequent layers.
- An exception is Switch routing, where the Math expert becomes dominant across tasks in the final model layer.
- With Top-2 routing and load balancing, the Code expert dominates in the Code domain, in contrast to the variants without load balancing, where the Math expert prevails across domains.
- Without load balancing, a dead-expert phenomenon is observed in the Code domain, where the routing probability to the Code expert shifts toward zero. With load balancing added, the probability distributions across experts become more similar, with slightly higher probability assigned to the Code expert.
- Experts specialize in different domains, as observed in the per-task routing distributions. The GSM8K task prefers the Code and Llama-2 experts, while the Math task relies more on the in-domain expert. This is attributed to GSM8K consisting of grade-school math word problems, which require common-sense knowledge and basic arithmetic rather than the college-level math knowledge the Math expert was trained on.
- In the Reasoning domain, tasks rely equally on the Math expert and the generalist LLM expert.
- Top-2 routing with load balancing yields a more uniform distribution of the load between experts than the other routing methods across all layers (a sketch of a typical load-balancing auxiliary loss follows this list).
- Routing probabilities per expert across layers on the HumanEval task are compared for Top-2 routing with and without load balancing, highlighting the differences introduced by the balancing term.
- Token-level routing decisions in the Math and Reasoning domains show clear expert preferences: the GSM8K task prefers the Code and Llama-2 experts, the Math task relies more on the in-domain expert, and the Reasoning domain splits the load between the Math and Llama-2 7B experts.
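Several of the points above hinge on a load-balancing term added to the router objective. As a concrete reference, the sketch below shows one common auxiliary loss of this kind (the formulation popularized by Switch Transformers); it is an illustrative assumption rather than necessarily the exact loss used in the paper, and the function and tensor names are hypothetical.

```python
import torch

def load_balancing_loss(router_logits: torch.Tensor, top_k: int = 2) -> torch.Tensor:
    """Auxiliary loss that pushes the router toward a uniform load across experts
    (Switch-Transformer-style; illustrative, not necessarily the paper's exact loss).

    router_logits: [num_tokens, num_experts] raw router scores for a batch of tokens.
    """
    num_experts = router_logits.shape[-1]
    probs = torch.softmax(router_logits, dim=-1)            # [tokens, experts]

    # Fraction of tokens dispatched to each expert under top-k routing.
    topk_idx = probs.topk(top_k, dim=-1).indices             # [tokens, k]
    dispatch = torch.zeros_like(probs).scatter_(-1, topk_idx, 1.0)
    tokens_per_expert = dispatch.mean(dim=0)                  # [experts]

    # Mean routing probability assigned to each expert.
    mean_prob_per_expert = probs.mean(dim=0)                  # [experts]

    # The scaled dot product is minimized when load is spread uniformly
    # across experts, discouraging dead experts.
    return num_experts * torch.sum(tokens_per_expert * mean_prob_per_expert)
```

Minimizing such a term alongside the language-modeling loss discourages the dead-expert behavior noted above.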
Summary
The paper proposes a method named Branch-Train-MiX (BTX) for training Large Language Models (LLMs) to possess capabilities in multiple specialized domains, such as coding, math reasoning, and world knowledge. BTX starts from a seed model, which is branched to train experts in an embarrassingly parallel fashion with high throughput and reduced communication cost. After the individual experts are asynchronously trained, BTX brings together their feedforward parameters as experts in Mixture-of-Experts (MoE) layers and averages the remaining parameters, followed by an MoE finetuning stage to learn token-level routing. This method generalizes two special cases: the Branch-Train-Merge method and sparse upcycling. Compared to alternative approaches, BTX achieves the best accuracy-efficiency tradeoff.
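To make the merge step concrete, the following is a minimal sketch of how branched experts could be mixed back into a single MoE model. It assumes PyTorch-style state dicts and an illustrative `mlp.` naming convention for feedforward parameters; the function name and key layout are hypothetical, and the paper does not prescribe this exact code.

```python
import torch

def mix_experts(expert_state_dicts, ffn_prefix="mlp."):
    """Combine N domain experts branched from the same seed model (illustrative sketch).

    Feedforward weights become per-expert parameters of an MoE layer; every other
    parameter (attention, embeddings, norms) is averaged across the experts.
    Assumes all experts share the seed model's architecture and parameter names.
    """
    merged = {}
    for key in expert_state_dicts[0].keys():
        tensors = [sd[key] for sd in expert_state_dicts]
        if ffn_prefix in key:
            # Keep each expert's feedforward weights as a separate MoE expert:
            # shape becomes [num_experts, ...original shape...].
            merged[key] = torch.stack(tensors, dim=0)
        else:
            # Average the non-feedforward parameters across experts.
            merged[key] = torch.stack(tensors, dim=0).mean(dim=0)
    return merged
```

The stacked feedforward weights then serve as the MoE experts that a newly initialized router selects among during the subsequent MoE finetuning stage.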
Previous Work on Branch-Train-Merge (BTM)
The paper highlights the challenges of synchronized training of Large Language Models (LLMs), such as high communication cost and vulnerability to hardware failures. Previous work on Branch-Train-Merge (BTM) proposed embarrassingly parallel training of LLMs without synchronization to improve pretraining throughput: multiple copies of a seed LLM are created and each copy is trained separately on a different subset of the data, so that each LLM becomes an expert specializing in its own data distribution.
The paper also describes how the Mixture-of-Experts (MoE) approach is used to scale deep networks while keeping only a subset of parameters active at any given time. The key advantage of BTX over standard MoE training is that expert training is embarrassingly parallel and asynchronous, which reduces communication cost and increases training throughput.
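To illustrate the "only a subset of parameters is active" point, here is a minimal top-2 MoE feedforward layer. This is a generic PyTorch sketch with assumed module and dimension names, not the paper's implementation.

```python
import torch
import torch.nn as nn

class TopKMoEFeedForward(nn.Module):
    """Generic sparse MoE feedforward layer (illustrative): each token is processed
    only by its top-k experts, so most expert parameters stay inactive per token."""

    def __init__(self, d_model=4096, d_hidden=11008, num_experts=4, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.router = nn.Linear(d_model, num_experts, bias=False)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_hidden), nn.SiLU(), nn.Linear(d_hidden, d_model)
            )
            for _ in range(num_experts)
        ])

    def forward(self, x):                       # x: [batch, seq, d_model]
        tokens = x.reshape(-1, x.shape[-1])     # [num_tokens, d_model]
        logits = self.router(tokens)            # [num_tokens, num_experts]
        weights, indices = torch.topk(torch.softmax(logits, dim=-1), self.top_k, dim=-1)
        weights = weights / weights.sum(dim=-1, keepdim=True)  # renormalize over top-k

        out = torch.zeros_like(tokens)
        for e, expert in enumerate(self.experts):
            # Positions (token, slot) where expert e was selected among the top-k.
            token_idx, slot_idx = torch.where(indices == e)
            if token_idx.numel() == 0:
                continue
            out[token_idx] += weights[token_idx, slot_idx, None] * expert(tokens[token_idx])
        return out.reshape(x.shape)
```

Each token activates only the feedforward weights of its top-2 experts, so the per-token active parameter count stays close to that of a dense layer even as experts (and total parameters) are added.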
Reference: https://arxiv.org/abs/2403.078...