BlackMamba: Mixture of Experts for State-Space Models

The development of Large Language Models (LLMs) built from decoder-only transformer models has played a crucial role in transforming the Natural Language Processing (NLP) domain, and has also advanced diverse deep learning applications including reinforcement learning, time-series analysis, and image processing. However, despite their scalability and strong performance, decoder-only transformer LLMs still face significant shortcomings. Although expressive, the attention mechanism requires substantial compute during both training and inference: memory that grows with sequence length and FLOPs that grow quadratically with it. This cost limits the context length of transformer models, makes autoregressive generation increasingly expensive as sequences grow, and hinders both learning from continuous data streams and truly unlimited sequence processing.

In recent times, State Space Models (SSMs) have demonstrated remarkable capabilities, competing with transformer models on large-scale language modeling benchmarks while achieving linear time and memory complexity as a function of sequence length. Mamba, a recently released SSM, has shown outstanding performance on a range of language modeling and long-sequence processing tasks. At the same time, Mixture of Experts (MoE) models have shown impressive performance while significantly reducing the latency and computational cost of inference, albeit at the expense of a larger memory footprint. Building on Mamba and MoE models, this article discusses BlackMamba, a novel architecture that combines the Mamba SSM with MoE to leverage the benefits of both frameworks. In experiments, BlackMamba outperforms the existing Mamba framework and transformer baselines in both training FLOPs and inference. These results suggest that BlackMamba effectively combines the strengths of Mamba and MoE: fast, cost-effective inference from MoE together with linear-complexity generation from Mamba.

This article aims to cover the BlackMamba framework in depth. We explore the mechanism, methodology, and architecture of the framework, along with its comparison to state-of-the-art language models. Let’s get started.

The progression of LLMs, particularly those based on decoder-only transformer architectures, has notably influenced the NLP field and expanded into deep learning applications such as reinforcement learning, time-series analysis, and image processing. Nonetheless, despite their scalability and robust performance, decoder-only transformer LLMs face notable challenges. The attention mechanism, a key feature of transformer-based LLMs, demands extensive computational resources for both training and inference: memory that grows with sequence length, and floating-point operations (FLOPs) that grow quadratically with it. These requirements restrict the models’ context length, raise the cost of autoregressive generation as sequences grow, and hinder the models’ ability to learn from continuous data streams or to process sequences of unlimited length efficiently.

Significant efforts have been made in the past few years to overcome these limitations, and attention has shifted towards architectural alternatives to the canonical dense-attention transformer, with SSMs and MoE models being the most promising candidates. The key benefit of SSMs over transformers is their computational complexity: linear in the input sequence length, as opposed to quadratic for attention. In theory, this lets SSMs process longer sequences than transformers for a given FLOP (floating-point operation) budget, and makes each autoregressive generation step constant in compute without a KV cache. Recently developed linear-complexity models such as Mamba and RetNet have demonstrated efficient long-sequence training and inference, along with language modeling performance and scaling properties competitive with transformers.

On the other hand, Mixture of Experts architectures are gaining popularity as an alternative to dense transformers because they significantly reduce the inference and training FLOPs needed to reach quality comparable to a dense model. MoE models activate only a sparse subset of their total parameters in each forward pass, using a routing function to decide which ‘experts’ are called into action based on the context. This decouples the computational cost of inference from the total parameter count, allowing better performance within a fixed inference budget, albeit with more parameters and a larger memory requirement.
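The linear-versus-quadratic argument can be made concrete with a toy cost model. This is a back-of-the-envelope sketch, not a FLOP count for any real model: it assumes each attention step costs one abstract unit per cached position, and each SSM step costs a constant `state_cost` units. The function names are hypothetical.

```python
def attention_generation_cost(seq_len: int) -> int:
    """Toy cost of generating seq_len tokens with attention and a KV cache.

    Token t (1-indexed) attends over the t positions cached so far, so the
    per-token cost grows with t and the total is 1 + 2 + ... + seq_len.
    """
    return sum(range(1, seq_len + 1))


def ssm_generation_cost(seq_len: int, state_cost: int = 1) -> int:
    """Toy cost of generating seq_len tokens with a recurrent SSM.

    Each step updates a fixed-size state, so every token costs the same
    (state_cost units) no matter how long the sequence already is.
    """
    return state_cost * seq_len


for n in (1_000, 2_000, 4_000):
    print(f"{n:>5} tokens: attention={attention_generation_cost(n):>9,}  ssm={ssm_generation_cost(n):>5,}")
```

Doubling the sequence length roughly quadruples the attention total (500,500 units at 1,000 tokens versus 2,001,000 at 2,000 tokens) but only doubles the SSM total.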

These architectural advances offer notable benefits over traditional transformers and represent an exciting direction for further development. The BlackMamba developers posit that combining them in a single Mamba-MoE model could improve language modeling quality and efficiency beyond that of standard transformer models. The anticipated advantages of a Mamba-MoE architecture over a traditional dense transformer include:

Mamba: Achieves linear computational complexity with respect to input sequence length in both training and inference, and enables autoregressive generation in constant time and constant memory per generated token.

MoE: Offers inference speed and training efficiency comparable to a smaller dense baseline model, while achieving quality that rivals a dense model with the same total number of parameters.
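The MoE trade-off can be illustrated by counting the weights in a single expert layer. This is a sketch under stated assumptions: 8 experts (the number BlackMamba uses), top-1 routing, hypothetical sizes `d_model = 1024` and `d_ff = 4096`, and plain two-matrix MLP experts without biases. Memory must hold every expert, but each token only pays for the ones it is routed to.

```python
def mlp_params(d_model: int, d_ff: int) -> int:
    """Weights of a plain two-layer MLP (up- and down-projection), biases ignored."""
    return 2 * d_model * d_ff


def moe_params(d_model: int, d_ff: int, num_experts: int, top_k: int) -> tuple[int, int]:
    """Return (total, active-per-token) weights of one routed expert layer.

    Every expert is stored in memory, but each token only runs through the
    top_k experts the router picks, plus the linear router itself.
    """
    expert = mlp_params(d_model, d_ff)
    router = d_model * num_experts  # linear router: one logit per expert
    total = num_experts * expert + router
    active = top_k * expert + router
    return total, active


dense = mlp_params(1024, 4096)
total, active = moe_params(1024, 4096, num_experts=8, top_k=1)
print(f"dense MLP: {dense:,}  MoE total: {total:,}  MoE active/token: {active:,}")
```

With these sizes the layer stores about 67.1 million weights while each token touches about 8.4 million: essentially the cost of the dense MLP plus an 8,192-weight router.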

That said, transformer models are still state of the art, and have demonstrated consistently strong performance on language modeling and sequence processing tasks. At its core, the transformer uses self-attention, which performs a quadratic all-to-all comparison of the dot-product similarities between the embeddings of the tokens in a sequence and then applies a linear map to produce the output vectors. A transformer model stacks self-attention blocks interleaved with MLP (Multi-Layer Perceptron) blocks, each of which is a two-layer MLP with a given activation function.
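A minimal single-head sketch of the causal self-attention used in decoder-only models makes the quadratic cost visible: token `i` is compared with all `i + 1` tokens up to and including itself. For brevity, it assumes the query, key, and value vectors are given directly and omits the learned projections, including the final linear output map.

```python
import math


def softmax(xs: list[float]) -> list[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def causal_self_attention(q, k, v):
    """Single-head causal self-attention over n tokens (no learned projections).

    q, k, v: lists of n vectors. Token i compares its query with the keys of
    tokens 0..i, so building all scores takes O(n^2) dot products.
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i in range(len(q)):
        # Scaled dot-product similarity between token i and every earlier token.
        scores = [sum(a * b for a, b in zip(q[i], k[j])) * scale for j in range(i + 1)]
        weights = softmax(scores)
        # Output = attention-weighted average of the value vectors.
        out.append([sum(w * v[j][c] for j, w in enumerate(weights)) for c in range(len(v[0]))])
    return out


x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(causal_self_attention(x, x, x))
```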

BlackMamba: Architecture and Methodology

State Space Models

State Space Models belong to the class of sequence models whose complexity is linear in the length of the input sequence. Architecturally, SSMs are closer to Recurrent Neural Networks and Convolutional Neural Networks than to attention-based models, and they are inspired by a continuous dynamical system that maps a 1-dimensional function through an implicit latent space. Because the system is linear, its computation can be parallelized efficiently using either an associative scan or a convolution. In practice, however, the recurrent nature of SSMs has historically hindered their adoption on highly parallel AI hardware such as GPUs. More recently, SSMs such as RWKV and Mamba have used parallel scan kernels to map recurrent operations efficiently onto GPUs, making it possible to train these architectures with efficiency comparable to transformers.
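The claim that a linear system can be parallelized with either a convolution or an associative scan can be checked on a toy example. The sketch below assumes the simplest possible SSM: a scalar state with fixed (time-invariant) coefficients `a`, `b`, `c`, all chosen for illustration. It computes the same outputs three ways: as a recurrence, as a causal convolution, and as a Hillis-Steele associative scan, whose rounds could each run in parallel. Real SSM layers use vector states and, in Mamba's case, input-dependent coefficients, which rule out the fixed convolution kernel but still admit the scan.

```python
def ssm_recurrent(xs, a, b, c):
    """Scalar linear SSM run as a recurrence: h_t = a*h_{t-1} + b*x_t, y_t = c*h_t."""
    h, ys = 0.0, []
    for x in xs:
        h = a * h + b * x
        ys.append(c * h)
    return ys


def ssm_convolution(xs, a, b, c):
    """Same SSM unrolled into a causal convolution with kernel K_k = c * a^k * b."""
    kernel = [c * (a ** k) * b for k in range(len(xs))]
    return [sum(kernel[k] * xs[t - k] for k in range(t + 1)) for t in range(len(xs))]


def combine(left, right):
    """Compose two affine steps h -> a*h + u (left first, then right); associative."""
    a1, u1 = left
    a2, u2 = right
    return (a1 * a2, a2 * u1 + u2)


def associative_scan(elems):
    """Hillis-Steele inclusive scan: log2(n) rounds whose updates are independent."""
    out = list(elems)
    step = 1
    while step < len(out):
        out = [combine(out[i - step], out[i]) if i >= step else out[i] for i in range(len(out))]
        step *= 2
    return out


def ssm_scan(xs, a, b, c):
    """Same SSM via the scan; the second component of each prefix is h_t."""
    return [c * h for _, h in associative_scan([(a, b * x) for x in xs])]


xs = [1.0, 2.0, -1.0, 0.5, 3.0]
for f in (ssm_recurrent, ssm_convolution, ssm_scan):
    print(f.__name__, [round(y, 4) for y in f(xs, a=0.9, b=0.5, c=2.0)])
```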

The quadratic complexity of transformers with respect to sequence length is a well-known limitation that impedes reasoning and comprehension over very long contexts. Recent innovations extend the context length, allowing transformers to be trained at a feasible scale and then applied to much longer contexts during inference. Despite these advances, inference still demands considerable compute and memory, especially for maintaining the Key-Value (KV) cache. Other recent research has focused on making state-space models more expressive by incorporating input-dependent gating mechanisms, akin to the Query, Key, Value (QKV) matrices found in attention.
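To see why the KV cache dominates inference memory while an SSM's state does not, the sketch below compares the two. All sizes (24 layers, 16 heads of dimension 64, fp16 storage, an SSM state of 2048 x 16 per layer, batch size 1) are hypothetical, and it ignores details such as grouped-query attention and Mamba's small convolution buffer.

```python
def kv_cache_bytes(seq_len, n_layers, n_heads, head_dim, bytes_per_elem=2, batch=1):
    """KV-cache size: one key and one value vector per head, layer, and cached token."""
    return 2 * n_layers * n_heads * head_dim * seq_len * bytes_per_elem * batch


def ssm_state_bytes(n_layers, d_inner, d_state, bytes_per_elem=2, batch=1):
    """Recurrent SSM state: a fixed d_inner x d_state matrix per layer, independent of seq_len."""
    return n_layers * d_inner * d_state * bytes_per_elem * batch


MIB = 2 ** 20
for seq_len in (4_096, 32_768):
    kv = kv_cache_bytes(seq_len, n_layers=24, n_heads=16, head_dim=64)
    print(f"{seq_len:>6} tokens: KV cache {kv / MIB:.1f} MiB")
print(f"SSM state (any length): {ssm_state_bytes(24, d_inner=2048, d_state=16) / MIB:.1f} MiB")
```

With these made-up sizes, the cache grows from 384 MiB at 4,096 tokens to 3,072 MiB at 32,768 tokens, while the SSM state stays at 1.5 MiB regardless of length.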

These efforts aim to preserve the inherently linear recurrence of state-space models, allowing efficient execution through either a convolution or a selective scan, and they significantly narrow the performance gap with transformers in practice. Among these advances, Mamba stands out: it pursues the same objectives as prior work and shows performance comparable to transformers at scales up to 2.8 billion parameters. It achieves this by applying input-dependent gating to the inputs of the state-space model (SSM) recurrence, while retaining efficient computation through custom selective scan kernels.
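A minimal, sequential sketch of the input-dependent gating idea, assuming a single input channel and a diagonal state matrix. The step size `delta` and the matrices `B` and `C` are computed from the current input, so the model can decide per token how strongly to update or decay its state. The parameter names and the simple per-channel projections are hypothetical; Mamba itself derives these quantities from projections of the full multi-channel input and evaluates the recurrence with a fused, hardware-aware selective scan kernel rather than a Python loop.

```python
import math


def softplus(z: float) -> float:
    # Numerically stable log(1 + exp(z)); keeps the step size positive.
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def selective_ssm(xs, A, w_delta, b_delta, w_B, w_C, D=0.0):
    """Single-channel selective SSM with a diagonal state matrix (sequential sketch).

    xs:               input sequence of scalars
    A:                N negative reals, the diagonal of the state matrix
    w_delta, b_delta: make the step size delta_t depend on x_t
    w_B, w_C:         N weights making B_t = w_B * x_t and C_t = w_C * x_t
    """
    h = [0.0] * len(A)
    ys = []
    for x in xs:
        delta = softplus(w_delta * x + b_delta)  # input-dependent step size
        B = [wb * x for wb in w_B]  # input-dependent input matrix
        C = [wc * x for wc in w_C]  # input-dependent output matrix
        # Discretize: A_bar = exp(delta * A); B_bar approximated as delta * B.
        h = [math.exp(delta * a) * hn + delta * bn * x for a, hn, bn in zip(A, h, B)]
        ys.append(sum(cn * hn for cn, hn in zip(C, h)) + D * x)
    return ys


print(selective_ssm([1.0, 0.0, -2.0, 0.5], A=[-1.0, -0.5], w_delta=1.0, b_delta=0.0,
                    w_B=[1.0, 0.5], w_C=[0.3, -0.2]))
```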

Mixture of Experts Models

Mixture of Experts (MoE) models decouple inference cost from total parameter count by selectively activating parameters during the forward pass. Instead of using all parameters, these models route each token to specific Multilayer Perceptron (MLP) experts. Ideally, each expert specializes in a particular type of input, and a routing mechanism, essentially a compact neural network, determines the most suitable expert for each token. The aim is to preserve the expressive power of a dense model with an equivalent number of parameters at a considerably lower computational cost. Typically, the router is a linear layer that maps tokens to expert indices, and each expert is simply a standard transformer MLP. However, the optimal way to train the router remains an open question, since expert assignment is non-differentiable, and MoE models often struggle with training stability and with balancing the load across experts, which is important for hardware efficiency.
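The routing described above can be sketched as a top-1 MoE layer in plain Python. This is an illustrative assumption, not BlackMamba's implementation: the router is a bias-free linear layer followed by a softmax, each token is sent to its single highest-scoring expert, and the toy "experts" are lambdas standing in for MLPs. The hard `argmax` choice is the non-differentiable step mentioned in the text.

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def matvec(W, x):
    """Multiply a matrix given as a list of rows by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def moe_layer(tokens, router_W, experts):
    """Top-1 routed MoE layer.

    router_W: num_experts x d_model weights of a linear router (no bias).
    experts:  list of callables, each mapping a d_model vector to a d_model vector.
    Returns the per-token outputs and the expert index chosen for each token.
    """
    outputs, assignments = [], []
    for x in tokens:
        probs = softmax(matvec(router_W, x))
        e = max(range(len(probs)), key=probs.__getitem__)  # hard, non-differentiable choice
        # Only expert e runs for this token; its output is scaled by the router
        # probability (Switch-Transformer-style) so the router weights still
        # influence the layer output.
        outputs.append([probs[e] * y for y in experts[e](x)])
        assignments.append(e)
    return outputs, assignments


experts = [lambda v: [2.0 * t for t in v], lambda v: [-t for t in v]]  # toy "experts"
router_W = [[1.0, 0.0], [0.0, 1.0]]
outs, assigned = moe_layer([[3.0, 1.0], [0.0, 2.0]], router_W, experts)
print(assigned)  # [0, 1]
```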

Architecture

BlackMamba starts from the standard transformer layout, in which MLP blocks and attention blocks are interleaved and added in sequence along a residual stream. Most Mixture of Experts models simply replace the MLP blocks with a routed expert layer. BlackMamba goes further: it replaces the MLP block with a routed expert layer and also replaces the attention layer with a Mamba State Space Model layer. The architecture of the BlackMamba framework is demonstrated in the following figure.
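The block structure described above can be sketched as a single layer acting on the residual stream. The sequence mixer and the MoE are passed in as plain functions; `running_sum` and the scaling lambda are hypothetical stand-ins for a real Mamba layer and routed expert layer, and the pre-norm placement with a gain-free RMSNorm is an assumption of this sketch rather than a detail stated here.

```python
import math


def rms_norm(x, eps=1e-6):
    """RMSNorm without a learned gain (assumed normalization for this sketch)."""
    scale = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * scale for v in x]


def add(a, b):
    return [p + q for p, q in zip(a, b)]


def blackmamba_layer(x_seq, mamba_mixer, moe_mlp):
    """One BlackMamba-style layer on the residual stream (pre-norm assumed).

    mamba_mixer: sequence -> sequence; mixes information across positions
                 (takes the place of self-attention).
    moe_mlp:     vector -> vector; applied per token (takes the place of the dense MLP).
    """
    mixed = mamba_mixer([rms_norm(x) for x in x_seq])
    x_seq = [add(x, m) for x, m in zip(x_seq, mixed)]
    return [add(x, moe_mlp(rms_norm(x))) for x in x_seq]


# Toy stand-ins (hypothetical): a causal running sum as the "mixer", a scaling as the "MoE".
def running_sum(seq):
    acc, out = [0.0] * len(seq[0]), []
    for v in seq:
        acc = add(acc, v)
        out.append(acc)
    return out


x = [[1.0, 2.0], [3.0, 4.0]]
for _ in range(2):  # stack two layers
    x = blackmamba_layer(x, running_sum, lambda v: [0.5 * t for t in v])
print(x)
```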

Training and Dataset

The BlackMamba models are trained on over 300 billion tokens drawn from a custom dataset, and they use the SwiGLU activation function in the expert MLPs. The framework uses 8 experts, a number the developers found to strike the right trade-off between memory footprint and inference cost. The custom dataset is a mixture of existing open-source datasets, including Starcoder, SlimPajama, Pile, and more. The following table shows the weight of each dataset used to train BlackMamba. Overall, the dataset contains 1.8 trillion tokens.
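For reference, a SwiGLU expert MLP multiplies a SiLU-activated "gate" projection element-wise with a second "up" projection before projecting back down. The sketch below assumes biases are omitted and uses tiny hypothetical weights; the actual hidden sizes of BlackMamba's experts are not given here.

```python
import math


def silu(z: float) -> float:
    # SiLU / Swish: z * sigmoid(z), with sigmoid written via tanh to avoid overflow.
    return z * 0.5 * (1.0 + math.tanh(0.5 * z))


def matvec(W, x):
    """Multiply a matrix given as a list of rows by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def swiglu_expert(x, W_gate, W_up, W_down):
    """SwiGLU MLP: W_down @ (silu(W_gate @ x) * (W_up @ x)), biases omitted."""
    gate = [silu(g) for g in matvec(W_gate, x)]
    up = matvec(W_up, x)
    hidden = [g * u for g, u in zip(gate, up)]  # element-wise gating
    return matvec(W_down, hidden)


# Hypothetical tiny weights: d_model = 2, hidden size = 2.
print(swiglu_expert([1.0, -1.0],
                    W_gate=[[1.0, 0.0], [0.0, 1.0]],
                    W_up=[[1.0, 1.0], [1.0, -1.0]],
                    W_down=[[1.0, 1.0]]))
```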

BlackMamba: Results

To ensure a fair comparison between Mamba and BlackMamba, the developers trained both models with the same training parameters on the same training data. For an identical forward-pass model size, BlackMamba outperforms both Mamba and transformer models in inference time as well as in training FLOPs. The following figure shows the time taken to autoregressively generate a sequence of a given length from an initial one-token prompt, as a function of sequence length.

BlackMamba combines the latency benefits of both Mixture of Experts and Mamba models, resulting in significantly faster inference than transformer models, pure Mamba models, and MoE models. Moreover, BlackMamba's inference advantage grows with sequence length, making it especially effective at long-sequence generation. The following figure illustrates the number of tokens assigned to each expert in the BlackMamba models with 340 million and 640 million parameters. As can be seen, most layers show a high degree of expert balance as a result of the improved Sinkhorn algorithm implemented in the BlackMamba models.
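The text credits BlackMamba's expert balance to an improved Sinkhorn algorithm whose details are not described here. For orientation only, the sketch below shows the standard Sinkhorn-Knopp iteration applied to a hypothetical token-by-expert score matrix. Alternately rescaling columns (equal load per expert) and rows (one unit per token) turns skewed router scores into a balanced assignment. It is not BlackMamba's variant.

```python
import math


def sinkhorn(logits, n_iters=50):
    """Sinkhorn-Knopp balancing of router scores.

    logits: n_tokens x n_experts router scores. Returns a positive matrix whose
    rows sum to 1 (each token spends one unit) and whose columns each sum to
    roughly n_tokens / n_experts (every expert gets an equal share).
    """
    n_tokens, n_experts = len(logits), len(logits[0])
    peak = max(max(row) for row in logits)  # shift for numerical stability
    P = [[math.exp(v - peak) for v in row] for row in logits]
    col_target = n_tokens / n_experts
    for _ in range(n_iters):
        # Column step: rescale each expert's column to the same total.
        col_sums = [sum(row[j] for row in P) for j in range(n_experts)]
        P = [[v * col_target / s for v, s in zip(row, col_sums)] for row in P]
        # Row step: rescale each token's row to sum to 1.
        row_sums = [sum(row) for row in P]
        P = [[v / s for v in row] for row, s in zip(P, row_sums)]
    return P


def argmax(row):
    return max(range(len(row)), key=row.__getitem__)


# Every token prefers expert 0, so greedy top-1 routing would overload it.
logits = [[2.0, 0.0], [1.5, 0.0], [1.0, 0.0], [0.5, 0.0]]
print([argmax(r) for r in logits])            # [0, 0, 0, 0]
print([argmax(r) for r in sinkhorn(logits)])  # [0, 0, 1, 1]
```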

The following table compares the evaluation scores of BlackMamba against a range of open-source pre-trained language models. As can be observed, BlackMamba competes with, and outperforms, the majority of these baselines. Notably, the models that outperform BlackMamba have considerably more parameters, and the performance gap is small, indicating that BlackMamba achieves strong results with fewer parameters.

Final Thoughts

In this article, we have talked about BlackMamba, a novel architecture that combines the Mamba State Space Model with Mixture of Experts models to reap the benefits of both. In experiments, BlackMamba outperforms the existing Mamba framework and transformer baselines in both training FLOPs and inference. These results show that BlackMamba successfully inherits and combines the strengths of its two components: the cheap, fast inference of MoE and the linear-complexity generation of Mamba. We have also discussed how the BlackMamba architecture is able to outperform strong pre-trained Large Language Models, the existing Mamba framework, and Mixture of Experts models in terms of training FLOPs and inference cost. In short, BlackMamba simultaneously inherits reduced training and generation FLOPs from both Mixture of Experts models and Mamba.

 
