Page 1 of 1

SAMBA 混合语言模型:关键概念解释

Posted: Tue Dec 03, 2024 3:52 am
by urrifat77
作为经常使用 OpenAI 的 ChatGPT 和 Anthropic 的 Claude 等高级模型的人,我亲眼观察到它们的性能如何随着输入提示长度的增加而下降,从而导致在保持扩展文本的连贯性和相关性方面出现问题。

为了解决这些限制,微软和伊利诺伊大学的研究人员推出了SAMBA——一种结合状态空间模型 (SSM) 与滑动窗口注意 (SWA) 的新型混合架构。

他们的方法充分利用了 SSM(擅长管理长期依赖关系)和 SWA(处理上下文窗口并保持计算可处理性)的优势。通过融合这些技术,SAMBA 实现了高效的语言建模,上下文长度几乎不受限制。

在本文中,我们将探索 SAMBA 的架构及其在不丢失上下文的情况下处理长文本跨度的独特能力。我们还强调了其显著增强语言模型处理和生成扩展序列的能力的潜力,为语言建模树立了新标准。

理解上下文瓶颈
要理解 SAMBA 为何如此创新,我们首先需要了解传统语言模型在处理长文本序列时面临的挑战。

Transformer 中的有限上下文
传统的基于 Transformer 的模型虽然功能强大,但在处理长文本序列时却面临巨大挑战,因为它们在上下文长度方面的复杂度为二次方。这种二次方复杂度源于自注意力机制,该机制要求每个标记关注序列中的每个其他标记。

因此,计算和内存成本随着序列长度的增加而迅速增长,使得这些模型对于需要处理非常长文本的任务来说不切实际。

这种限制常常迫使我们截断输入或使用其他次优策略来适应可用硬件的限制。最终,这种妥协降低了模型在扩展序列上保持性能的能力,这是我在开发需要处理长文档的应用程序时遇到的一个挑战。

SSM 及其局限性
状态空间模型 (SSM) 提供了一种具有线性复杂度的替代方案,使其在处理长序列时具有更高的计算效率。SSM 保持不断发展的状态,使其能够处理扩展的依赖关系,而无需使用成本过高的转换器。

然而,SSM 也有其局限性。由于其马尔可夫性质,即当前状态仅取决于先前状态,它们通常难以在长序列上回忆。这种有限的回忆降低了它们在综合上下文建模中的有效性,特别是在需要保留和参考序列中更早的信息的应用中。

需要混合方法
鉴于 Transformers 和 SSM 各有优缺点,迫切需要一种能够充分利用各自优势并减轻其局限性的混合方法。将 SSM 与注意力机制相结合是一种有前途的解决方案。

这种混合方法利用了 SSM 的效率和长程依赖性以及 Transformers 的动态和集中注意力机制。通过整合这两种方法,我们可以创建一个能够流畅处理长序列并增强记忆回忆和上下文理解的模型。

SAMBA:简单混合状态空间模型
SAMBA 为这一上下文瓶颈提供了一个优雅的解决方案,结合了两种不同方法的优势。

核心理念
SAMBA 背后的核心思想是将Mamba(一种 SSM)与 SwiGLU 和滑动窗口注意 (SWA) 层交错。这种混合结构既能捕捉循环结构,又能精确检索记忆。

SAMBA 通过结合 SSM 和注意力 rcs 数据库 机制的优势来管理长上下文同时保留详细信息,从而体现了这种方法。

该图说明了 SAMBA 架构,其中 Mamba、MLP、SWA、MLP 模式重复 N/4 次,其中 N 是层数。

SAMBA 架构图。来源:Ren 等人(2024 年)

Mamba 层
SAMBA 中的 Mamba 层擅长捕捉时间相关语义,为处理顺序数据提供强大的框架。这些层通过维护和更新反映数据内时间依赖性的状态来运行。

Mamba 通过利用选择性状态空间来实现这一点,这些空间允许模型专注于相关输入并在长序列中保留重要信息。这种选择性门控机制对于快速解码至关重要,可确保模型能够以高精度和最小计算开销解释和预测序列模式。

SWA 层
滑动窗口注意层通过解决有限上下文窗口内的复杂非马尔可夫依赖关系来补充 Mamba 层。SWA 以在输入序列上滑动的窗口大小进行操作,确保线性计算复杂度。

这使得模型能够从中期到短期历史中检索高清信号,而这些信号无法被 Mamba 的循环状态捕获。通过动态调整其焦点,SWA 层使模型能够保持连贯性和上下文,特别是对于需要对长输入做出上下文相关响应的任务。

SwiGLU 层
SAMBA 中的 SwiGLU 层有助于非线性转换并增强知识回忆。这些层将非线性引入模型,使其能够捕获数据中更复杂的模式和交互。

此外,SwiGLU 层确保模型能够处理和调用信息,从而提高其稳健性和多功能性。这种非线性转换对于模型从训练数据推广到实际应用的能力至关重要。

SAMBA:性能和可扩展性
探索了 SAMBA 的架构后,现在让我们与其他模型相比来检查一下它的性能和效率。

基准测试表现强劲
SAMBA 在各种语言理解和推理基准测试中表现出色,优于纯基于注意力和基于 SSM 的模型。具体来说,SAMBA 已在 MMLU、GSM8K 和 HumanEval 等任务上进行了评估,MMLU 得分为 71.2,GSM8K 得分为 69.6,HumanEval 得分为 54.9。

SAMBA 性能基准

来源:Ren 等(2024 年)

这些结果显著超越了其他最先进的模型,包括 TFM++ 和 Llama-3,展示了 SAMBA 处理各种语言理解任务的能力。例如,与 TFM++ 相比,SAMBA 在 GSM8K 中的准确率高出 18.1%,凸显了其将 SSM 与注意力机制相结合的混合架构的熟练程度。

高效长度外推
SAMBA 最显著的特点之一是它能够处理明显更长的上下文长度,同时保持效率。尽管 SAMBA 是在 4K 长度的序列上进行预训练的,但它可以推断出最多 1M 个标记,并且困惑度有所改善,同时仍保持线性解码时间复杂度。

这是通过将 Mamba 的选择性状态空间与 SWA 逐层组合实现的,使得模型能够在没有二次计算复杂度的情况下保持高性能。

Image

实际上,与 Llama-3 架构相比,SAMBA 实现了 3.64 倍的解码吞吐量,特别是对于长度高达 128K 标记的序列,证明了其可扩展性和处理长上下文的能力。

提高记忆回忆能力
与纯 SSM 相比,SAMBA 的混合架构显著增强了其记忆回忆能力。在诸如 Passkey Retrieval 之类的测试中,SAMBA 在仅经过 500 步微调后,就几乎完美地实现了长达 256K 上下文长度的记忆回忆,而基于 SWA 的模型在超过 4K 长度时就遇到了困难。

这种出色的性能归功于 Mamba 的时间相关语义循环结构和 SWA 的记忆检索能力的综合优势。因此,SAMBA 在短期和长期记忆回忆任务中表现出色,使其成为需要广泛上下文理解的应用程序的强大解决方案。