Jamba: A Hybrid Transformer-Mamba Language Model
Mike Young
Posted on April 11, 2024
This is a Plain English Papers summary of a research paper called Jamba: A Hybrid Transformer-Mamba Language Model. If you like these kinds of analysis, you should subscribe to the AImodels.fyi newsletter or follow me on Twitter.
Introduction
The paper introduces Jamba, a new publicly available large language model with a novel hybrid architecture. Jamba combines Transformer layers with Mamba layers (a state-space model) and a mixture-of-experts component. This hybrid design aims to improve performance, increase throughput, and maintain a manageable memory footprint.
The key novelty of Jamba is its combination of the Transformer architecture, known for its strong performance, with the Mamba architecture, which excels at handling long contexts and efficient training. By varying the ratio of Transformer and Mamba layers, Jamba can balance memory usage, training efficiency, and long context capabilities.
The paper discusses previous attempts to combine attention and state-space models, noting that Jamba is the first production-grade model of this type. It also incorporates mixture-of-experts layers, allowing for increased model capacity without proportionally increasing compute requirements.
Jamba's performance is comparable to similarly sized models like Mixtral-8x7B and Llama-2 70B, and it excels on long-context evaluations. It also boasts high throughput and can fit on a single 80GB GPU even with contexts over 128K tokens.
The authors have released Jamba (12B active parameters, 52B total parameters) under an open-source license to encourage further study and optimization by the community. However, they note that the released model is a pretrained base without additional tuning or moderation mechanisms.
Model Architecture
The Jamba architecture is a hybrid decoder that combines three key components: Transformer layers, Mamba layers (a recent state-space model), and a mixture-of-experts (MoE) module. A combination of these three components is called a Jamba block. Figure 1 of the paper illustrates the architecture.
Figure 1: (a) A single Jamba block. (b) Different types of layers. The implementation shown here is with l = 8, an a:m = 1:7 ratio of attention-to-Mamba layers, and MoE applied every e = 2 layers.
The paper describes the Jamba architecture, which combines transformer, Mixture of Experts (MoE), and Mamba elements to balance memory usage, throughput, and model quality. Key points:
- Total model parameters can be misleading for MoE models, as only a subset of the parameters is active during inference.
- The key-value (KV) cache size for storing attention keys/values is a limiting factor, especially for long sequences. Jamba aims for a smaller KV cache compared to standard transformers.
- Replacing attention layers with more compute-efficient Mamba layers improves throughput, especially for long sequences.
- Jamba blocks contain a mix of attention and Mamba layers, with multi-layer perceptrons (MLPs) that can be replaced with MoE layers.
- Configurable parameters include: number of layers, attention-to-Mamba ratio, MoE frequency, number of experts per layer, and number of top experts used.
- Increasing the Mamba ratio reduces KV cache size but may lower quality. More MoE experts increase capacity but use more memory.
- Mamba layers use RMSNorm for stable training at scale. No explicit positional embeddings are used.
- Other standard components like grouped query attention and SwiGLU activations are used.
- The architecture allows flexibility in optimizing for different objectives by tuning the configurable parameters.
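To make the configurable parameters concrete, here is a minimal Python sketch that lays out one block from the values in the Figure 1 caption (l = 8, a:m = 1:7, e = 2). The function name and the `attn_offset` and `moe_offset` arguments are hypothetical: the summary gives the ratios but not where inside a block each layer type sits, so the placement below is an assumption.

```python
def jamba_block_layout(num_layers=8, attn_to_mamba=(1, 7), moe_every=2,
                       attn_offset=0, moe_offset=1):
    """Return a list of (mixer, ffn) pairs describing one block.

    attn_offset and moe_offset are hypothetical knobs: the summary gives the
    ratios but not the exact position of each layer type inside a block.
    """
    a, m = attn_to_mamba
    period = a + m  # one group of `a` attention layers plus `m` Mamba layers
    if num_layers % period != 0:
        raise ValueError("num_layers must be a multiple of a + m")
    layout = []
    for i in range(num_layers):
        # The first `a` slots of each period (shifted by attn_offset) are attention.
        mixer = "attention" if (i - attn_offset) % period < a else "mamba"
        # Every `moe_every`-th layer swaps its MLP for an MoE layer.
        ffn = "moe" if i % moe_every == moe_offset % moe_every else "mlp"
        layout.append((mixer, ffn))
    return layout


block = jamba_block_layout()
for index, (mixer, ffn) in enumerate(block):
    print(f"layer {index}: {mixer:9s} + {ffn}")

# Stacking four such blocks gives a 32-layer model.
model = block * 4
print(sum(mixer == "attention" for mixer, _ in model), "attention layers")
print(sum(ffn == "moe" for _, ffn in model), "MoE layers")
```

Changing `attn_to_mamba` or `moe_every` shows the trade-off directly: fewer attention layers mean fewer layers that need a KV cache, and more MoE layers mean more total parameters for the same number of active ones.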
Reaping the Benefits
The paper describes the implementation details of Jamba, a large language model designed to fit on a single 80GB GPU while achieving high performance in terms of quality and throughput.
Jamba consists of four Jamba blocks, each with 8 layers, for 32 layers in total. The ratio of attention to Mamba layers is 1:7, so each block has one attention layer and seven Mamba layers. Every other layer uses a mixture of experts (MoE) in place of a single MLP. Each MoE layer has 16 experts, of which the top 2 are used for each token.
This configuration was chosen to balance model quality, compute requirements, and memory transfers while fitting on an 80GB GPU. It allows for up to 1M token context length during training, and the released model supports up to 256K tokens.
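The "16 experts, top 2 per token" setup can be illustrated with a small routing sketch. The summary does not describe the router itself, so the scheme below (pick the two highest router scores, then softmax over just those two) is an assumption based on common top-k gating, and the toy experts are purely illustrative.

```python
import math


def top_k_route(router_logits, k=2):
    """Pick the k highest-scoring experts and softmax-normalize their weights."""
    ranked = sorted(range(len(router_logits)),
                    key=lambda i: router_logits[i], reverse=True)
    chosen = ranked[:k]
    # Softmax over the chosen logits only (subtract the max for stability).
    peak = max(router_logits[i] for i in chosen)
    exps = [math.exp(router_logits[i] - peak) for i in chosen]
    total = sum(exps)
    return [(i, e / total) for i, e in zip(chosen, exps)]


def moe_forward(token, experts, router_logits, k=2):
    """Weighted sum of the outputs of the k selected experts for one token."""
    output = [0.0] * len(token)
    for expert_index, weight in top_k_route(router_logits, k):
        expert_out = experts[expert_index](token)
        output = [o + weight * y for o, y in zip(output, expert_out)]
    return output


# 16 toy "experts": expert j just scales the token by (j + 1).
experts = [lambda x, s=j + 1: [s * v for v in x] for j in range(16)]
logits = [0.0] * 16
logits[3], logits[9] = 2.0, 2.0  # two experts share the top score

routes = top_k_route(logits)
print(routes)
print(moe_forward([1.0, -1.0], experts, logits))
```

Only 2 of the 16 experts run for any given token, which is why the active parameter count (12B) is much smaller than the total (52B).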
In terms of throughput, Jamba achieves 3x higher throughput than Mixtral on a single GPU with a batch size of 16 and 8K context length. On 4 GPUs with 128K context length, Jamba's throughput is 3x higher than Mixtral's, even though Jamba has not yet benefited from the kind of optimizations that have been developed for pure Transformer models like Mixtral.
The paper highlights that Jamba enables significantly longer context lengths compared to other recent open models like Mixtral and Llama-2-70B when fitting on an 80GB GPU.
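A rough back-of-the-envelope calculation shows why the attention-to-Mamba ratio matters so much for long contexts. The sketch below compares the KV cache of a 32-layer stack with 4 attention layers (the 1:7 ratio) against a hypothetical 32-layer stack in which every layer is attention. The head count and head dimension (8 KV heads of size 128) and the 16-bit cache are assumptions for illustration and are not taken from this summary. The small, fixed-size state kept by the Mamba layers is ignored.

```python
def kv_cache_bytes(attention_layers, kv_heads, head_dim, context_tokens,
                   bytes_per_value=2, batch_size=1):
    """Bytes needed to cache keys and values for every attention layer."""
    # Factor of 2: one key vector and one value vector per head per token.
    per_token = attention_layers * kv_heads * head_dim * 2 * bytes_per_value
    return per_token * context_tokens * batch_size


GIB = 1024 ** 3
CONTEXT = 256 * 1024  # 256K tokens

# Assumed shapes: 8 KV heads of dimension 128 in both cases (not from the summary).
hybrid = kv_cache_bytes(attention_layers=4, kv_heads=8, head_dim=128,
                        context_tokens=CONTEXT)
dense = kv_cache_bytes(attention_layers=32, kv_heads=8, head_dim=128,
                       context_tokens=CONTEXT)

print(f"4 attention layers:  {hybrid / GIB:.1f} GiB")
print(f"32 attention layers: {dense / GIB:.1f} GiB")
print(f"ratio: {dense / hybrid:.0f}x")
```

Under these assumptions the cache shrinks in direct proportion to the number of attention layers, which is the mechanism behind fitting longer contexts on a single 80GB GPU.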
Training Infrastructure and Dataset
The model was trained on NVIDIA H100 GPUs using an in-house proprietary framework that enabled efficient large-scale training through techniques like FSDP, tensor parallelism, sequence parallelism, and expert parallelism. Jamba was trained on an in-house dataset containing text data from the Web, books, and code, last updated in March 2024. The data processing pipeline included quality filters and deduplication.
Evaluation
The paper presents performance results of the proposed Jamba model on various academic benchmarks and long-context evaluations. Key points:
Academic Benchmarks:
- Jamba performs comparably or better than leading publicly available models like Llama-2 70B and Mixtral on benchmarks covering reasoning, reading comprehension, and others.
- Despite having fewer total parameters (52B) than Llama-2 70B, Jamba achieves strong performance while offering up to 3x better throughput.
Long-Context Evaluations:
- Jamba can handle contexts up to 1M tokens, with the released model supporting 256K tokens.
- It shows excellent performance on the needle-in-a-haystack evaluation, which tests recall of statements in long contexts.
- On naturalistic long-context QA benchmarks (up to 62K tokens), Jamba outperforms Mixtral on most datasets and has better average performance.
- Jamba's efficiency shines on these long-context tasks, offering much better throughput.
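For readers unfamiliar with the needle-in-a-haystack evaluation mentioned above, the idea can be sketched in a few lines: hide one statement at a chosen depth inside a long filler text, ask a question about it, and check whether the answer recovers it. This is a generic illustration, not the paper's evaluation harness. The function names, the filler text, the needle, and the substring-based scoring are all hypothetical.

```python
def build_haystack_prompt(filler_sentences, needle, depth, question):
    """Insert `needle` at a relative `depth` (0.0 = start, 1.0 = end)."""
    if not 0.0 <= depth <= 1.0:
        raise ValueError("depth must be between 0 and 1")
    position = round(depth * len(filler_sentences))
    sentences = filler_sentences[:position] + [needle] + filler_sentences[position:]
    return " ".join(sentences) + "\n\nQuestion: " + question + "\nAnswer:"


def recalled(model_answer, expected):
    """Crude scoring: does the answer contain the expected string?"""
    return expected.lower() in model_answer.lower()


filler = [f"Filler sentence number {i}." for i in range(1000)]
needle = "The secret code for the vault is 4921."
prompt = build_haystack_prompt(filler, needle, depth=0.5,
                               question="What is the secret code for the vault?")

# A real run would send `prompt` to a model; here the answer is a stand-in.
print(recalled("The code is 4921.", "4921"))
```

Sweeping `depth` and the number of filler sentences produces the grid of context length versus needle position that such evaluations typically report.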
The paper highlights Jamba's ability to reach state-of-the-art performance while leveraging the benefits of a hybrid architecture with improved efficiency.
Ablations and Insights
The section discusses ablation experiments conducted to evaluate different design choices for the Jamba architecture, which combines attention and Mamba (state-space) layers. Key findings include:
- Combining attention and Mamba layers improves performance over pure attention or pure Mamba models. A ratio of 1 attention layer to 7 Mamba layers works well.
- The pure Mamba model struggles with in-context learning capabilities, while the hybrid Attention-Mamba model exhibits in-context learning similar to vanilla Transformers. Visualizations suggest the attention layers develop induction heads that support in-context learning.
- Adding a Mixture-of-Experts (MoE) layer further improves the performance of the hybrid Attention-Mamba architecture at larger scales.
- Special normalization (RMSNorm) is required to stabilize training of Mamba layers at very large scales.
- Explicit positional information is not needed in Jamba, as the Mamba layers likely provide implicit position information.
The authors present results on academic benchmarks, log-probability evaluations, and other tasks to support these findings. Overall, the hybrid Attention-Mamba architecture with MoE outperforms pure attention or Mamba models.
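As a reminder of what RMSNorm does, here is a minimal pure-Python sketch of the standard formulation: divide a vector by its root mean square, then apply a learned per-element gain. Where exactly the normalization is applied inside the Mamba layers is not stated in this summary, so the sketch shows only the operation itself. The `eps` value is an assumed default.

```python
import math


def rms_norm(x, gain=None, eps=1e-6):
    """Scale a vector so its root mean square is ~1, then apply a per-element gain."""
    if gain is None:
        gain = [1.0] * len(x)
    mean_square = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_square + eps)
    return [g * v * scale for g, v in zip(gain, x)]


small = rms_norm([0.1, -0.2, 0.3])
large = rms_norm([1000.0, -2000.0, 3000.0])  # same direction, 10,000x larger
print(small)
print(large)
```

Both inputs come out at nearly the same scale, which hints at why the normalization helps with stability: large activation spikes are rescaled before they propagate to later layers.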
Conclusion
The paper presents Jamba, a novel architecture that combines Attention and Mamba layers with Mixture-of-Experts (MoE) modules. It provides an open implementation of Jamba, achieving state-of-the-art performance while supporting long contexts. The architecture offers flexibility in balancing performance and memory requirements while maintaining high throughput. The researchers experimented with various design choices, such as the ratio of Attention-to-Mamba layers, and discussed discoveries made during the development process, which will inform future work on hybrid attention–state-space models. The authors plan to release model checkpoints from smaller-scale training runs to facilitate further research in this area. The largest model provided with this release has 12 billion active and 52 billion total available parameters, supporting context lengths of up to 256,000 tokens and fitting on a single 80GB GPU even when processing texts up to 140,000 tokens.