架构改进
Flash Attention
痛点: 标准 Attention 的计算需要频繁在 GPU 的 HBM(显存) 和 SRAM(高速缓存) 之间搬运数据。由于显存读写速度远慢于计算速度,GPU 很多时间都在“等数据”,导致利用率低下。
Flash Attention 采用了 分块 技术。
- 它不一次性计算整个 \(L \times L\) 的注意力矩阵,而是把矩阵切成小块。
- 每一小块拉进 SRAM 后,直接在里面算完结果,再写回 HBM。
结果减少了大量数据搬运,将内存读写复杂度从 \(O(L^2)\) 降到了 \(O(L)\)。
同时在反向传播中使用了重算技术
- 标准做法: 为了求导,训练时要把巨大的 \(S\) 矩阵存在显存里。
- Flash 做法: 我不存了!反向传播时,我根据 \(Q, K, V\) 现场重新算一遍。
反直觉的结论是,现场重算的开销,竟然比从显存读取(IO)还要快。
分块原理
要理解 Flash Attention,我们不能只看“数学计算量”,必须看 GPU 的硬件架构。
在 GPU 里,有两个存储区域:
- HBM: 容量大,约80GB,但速度慢。
- SRAM : 速度极快,但容量极小,每层只有几百 KB 到几 MB。
标准 Attention 的笨办法:
- 从 HBM 读取 \(Q, K\) 到 SRAM,计算 \(QK^T = S\) ,即注意力的分数矩阵。
- 必须把巨大的 \(S\) 矩阵写回 HBM,因为 SRAM 放不下整个 \(L \times L\) 的矩阵。
- 再从 HBM 把 \(S\) 读回来,做 Softmax。
- 再把结果写回 HBM。
- 再把 \(S\) 读回来,乘以 \(V\)...
发现问题了吗? GPU 的计算核心(CUDA Core)其实转得飞快,但它大部分时间都在打哈欠,等待数据从慢速的 HBM 搬运过来。这在工程上叫 Memory-bound(内存带宽受限)。
Flash Attention 的核心逻辑是:既然搬运慢,那我就尽量不出 SRAM,在里面把活儿全干完。
它把 \(Q, K, V\) 矩阵切成一个个小方块,小到足以塞进 SRAM。
- 它从 HBM 读入一小块 \(Q\) 和一小块 \(K\)。
- 在 SRAM 里直接算出这一小块的 \(Attention(Q, K, V)\) 结果。
- 关键点: 它不需要把那个巨大的中间矩阵 \(S\) (\(L \times L\)) 写回显存。它只把最终计算好的这一小块结果写回去。
Online Softmax
你可能会问:“Softmax 需要知道全句的最大值和总和,只看一小块怎么算?”
- Flash Attention 用了一个数学技巧:它在分块计算时,动态地更新最大值和累加和。
- 通过不断迭代修正,最后的结果和全局计算的一模一样。
重算
Flash Attention 是怎么“反向传播”的?
- 前向传播时: 我只把 \((L, D)\) 这么大的 \(Q, K, V\) 存在显存里。那个 \((L, L)\) 的 \(S\) 算完就扔,绝不写回显存。
- 反向传播到这一层时: 我从显存里读取 \((L, D)\) 的 \(Q, K, V\)。
- 在这一层内部,我重新把它们切成小块(Tiles),在 SRAM 里现场复现出一小块 \(S\)。
- 用这一小块 \(S\) 算完梯度,立刻抹除,再算下一块。
为什么 \(Q, K, V\) 比 \(S\) 小得多?
我们拿 Llama-3-8B 的真实参数,算一下长文本(32k)情况下的数据量:
- 假设 \(L = 32,768\) (3.2万个词)
- 假设 \(D = 4096\)
\(Q, K, V\) 的大小(线性增长):
每个词都有自己的 \(Q, K, V\) 向量。
- \(Q\) 矩阵的形状是 \((L, D)\) \(\rightarrow\) \(32,768 \times 4,096 \approx\) 1.34 亿个元素。
\(S\) 矩阵(注意力分数矩阵)的大小(平方增长):
\(S\) 是每个词都要跟其他所有词打分,形状是 \((L, L)\)。
- \(S\) 矩阵的形状是 \((32,768, 32,768) \approx\) 10.7 亿个元素。
在这个例子里,\(S\) 的元素数量是 \(Q\) 的 8 倍。
更恐怖的是:如果上下文增加到 128k: - \(Q, K, V\) 的大小只增加了 4 倍。 - \(S\) 的大小增加了 \(4 \times 4 =\) 16 倍!
- 显存占用: 128k 的 \(S\) 矩阵如果用 FP16 存储,光这一个中间矩阵就要占 32GB 显存。这还没算梯度和其他层的开销,单张显卡直接就爆了。
这就是为什么 Flash Attention 要“消灭” \(S\): 因为 \(Q, K, V\) 这种 \((L, D)\) 形状的矩阵,显存还勉强能背得动;但 \(S\) 这种 \((L, L)\) 的大方阵,长度一长就是所有显卡的噩梦。
Mixture of Experts (MoE)
如果说增加参数量能让模型变强,那 MoE 就是在不增加单次推理开销的前提下,把参数量推到万亿级的“黑科技”。
- 核心组件:
- Experts (专家): 多个并行的小型 FFN 网络。
- Router (路由/门控): 一个分类器,决定当前的 Token 该交给哪 1-2 个专家处理。
- 优势: 激活参数量 \(\neq\) 总参数量。比如一个 1.8T 的模型(如 GPT-4 传闻),每次推理可能只激活了 2 个专家,计算量可能只相当于一个 100B 的模型。
- 面试考点(负载均衡): 如果所有 Token 都去找同一个“明星专家”,其他专家就会闲死。因此需要加入 Auxiliary Loss(辅助损失),强制让路由把任务分得均匀点。
在 MoE 的演进过程中,架构的进化始终围绕着三个核心矛盾:计算效率(稀疏性)、专家利用率(负载均衡) 和 知识冗余(泛化性)。
以下是 MoE 发展史上的几个经典里程碑架构:
Shazeer MoE (Sparsely-Gated MoE)
背景: 2017 年由 Google 提出,是现代深度学习中稀疏 MoE 的鼻祖。
- 解决问题: 如何在不增加计算量的前提下,成倍扩大模型容量。
- 关键技术:
- Top-k Gating: 引入门控网络,对每个 Token 只激活前 \(k\) 个专家(通常 \(k=1\) 或 \(2\))。
- 噪声注入 (Noisy Top-k): 在门控激活前加入高斯噪声,防止“马太效应”(即强者愈强,导致某些专家从未被训练)。
- 局限性: 训练极不稳定;通信开销巨大,难以在大规模分布式集群上高效运行。
Switch Transformer
背景: 2021 年 Google 提出,将模型参数推向了万亿级(1.6 Trillion)。
- 解决问题: 进一步简化 MoE,降低通信和计算复杂度。
- 关键技术:
- Switch Routing (Top-1): 极端简化,每个 Token 只发给 1个 专家。这极大地减少了路由的计算和通信带宽。
- 容量因子 (Capacity Factor): 限制每个专家能处理的最大 Token 数,多出的部分直接丢弃或传给下一层。
- 局限性: Top-1 策略虽然快,但性能损失明显(相比 Top-2),模型容易陷入不收敛的窘境。
GLaM (Generalist Language Model)
背景: Google 在 2022 年提出的架构,证明了 MoE 在零样本学习上的强大。
- 解决问题: 如何在超大规模预训练中平衡性能与推理能效。
- 关键技术:
- 专家交替: 并不是每一层都是 MoE,而是每隔一层放置一个 MoE 层,中间穿插普通的稠密 FFN 层。
- 局限性: 依然面临显存占用的巨大压力(需要加载所有专家参数)。
Mixtral 8x7B (MoE for the Masses)
背景: Mistral AI 在 2023 年开源,彻底引爆了开源社区对 MoE 的热情。
- 解决问题: 如何让中等规模的模型拥有超越其尺寸的性能。
- 关键技术:
- Sliding Window Attention + MoE: 结合了长文本处理能力和 MoE 的推理效率。
- 精细化的并行优化: 针对消费级显卡优化了路由算法,让 8x7B 模型在性能上对标 Llama 2 70B,但推理速度快得多。
- 局限性: 专家数量固定(8个),且每个专家都是完整的 FFN,缺乏更细粒度的知识划分。
DeepSeek-V2 / V3 (Multi-head Latent Attention & DeepSeekMoE)
背景: 近两年最强国产开源模型,代表了目前 MoE 的巅峰设计。
- 解决问题: 解决传统 MoE 中的专家冗余和知识退化问题。
- 关键技术:
- 细粒度专家 (Fine-Grained Experts): 将传统的几个大专家拆成几十个小专家,提高知识组合的灵活性。
- 共享专家 (Shared Experts): 始终激活一部分专家处理所有 Token。这些共享专家负责存储“公共常识”,而专用专家负责“专项知识”,有效缓解了不同专家之间重复学习常识的问题。
- 无损负载均衡: 采用辅助损失函数之外的动态调整策略,解决了专家闲置问题。
| 架构 | 核心贡献 | 缺点 |
|---|---|---|
| Shazeer MoE | 定义了稀疏门控逻辑 | 训练极不稳定,同步难 |
| Switch Transformer | 验证了 Top-1 路由的极致速度 | 精度有所妥协,容易丢弃 Token |
| Mixtral | 证明了 MoE 在中型模型上的降维打击 | 架构较为传统,专家利用率一般 |
| DeepSeekMoE | 共享+细粒度专家,知识利用率最高 | 路由逻辑复杂,对推理调度要求极高 |
Sparse Attention (稀疏注意力)
随着模型需要处理的文本越来越长,传统的注意力机制已经成为了性能瓶颈。在标准的 Transformer(如 GPT-2, BERT)中,使用的是全注意力。每个 Token 都要和序列中所有其他 Token 计算关联度。如果序列长度为 \(N\),计算量和内存占用会呈 \(O(N^2)\) 平方级增长。当你想让模型读一整本书时,显存会瞬间溢出,计算速度也会慢得无法接受。
稀疏注意力的目标: 将复杂度从 \(O(N^2)\) 降低到 \(O(N \log N)\) 甚至 \(O(N)\),从而支持超长文本的处理。
模式化稀疏
通过预设的几何规则来限制注意力的范围,比如限制模型只能观察固定窗口内的上下文。不过显然会有一些问题,因为这样预设的窗口太过死板了,其实是给了模型很强的先验知识,类似于卷积网络的局部性假设,而transformer一般是不依靠先验知识,只靠海量数据来学习的。全局注意力几乎没有任何归纳偏置,他假设任何两个词之间都可能存在联系,而人为规定可能会限制模型的能力。
具体实现
为了打破 \(O(N^2)\) 的限制,模式化稀疏通常由以下三种机制组合而成:
A. 滑动窗口(Sliding Window / Local Attention)
- 实现: 规定每个 Token 只能看到其左右半径为 \(w\) 的邻居。
- 逻辑: 效仿人类阅读,理解一个词通常只需要上下文。这让计算量从 \(N \times N\) 降到了 \(N \times w\)(线性复杂度)。
B. 扩张窗口(Dilated Window / Strided Attention)
- 实现: 类似于空洞卷积。Token 不再看连续的邻居,而是每隔 \(k\) 个词看一个。
- 逻辑: 在不增加计算量的前提下,通过“跳跃”感知更远的背景信息。
C. 全局锚点(Global Attention / BigBird)
- 实现: 选定极少数 Token(如第一个词或专门增加的虚假 Token)与序列中所有词进行计算。
- 逻辑: 它们充当“信息班车”,把序列末尾的信息带回开头。
为什么这种死板的设计会限制模型?
正如刚刚所说的,这种硬性的规则会带来三个严重的问题:
① 割裂了非局部的语义联系
在自然语言中,存在大量的长程依赖(Long-range Dependencies)。
例子: “张三在清华大学读完了本科,随后去美国深造,拿到了博士学位,回国后他成为了教授。”
如果窗口大小只有 10 个词,当模型处理到“教授”时,由于滑动窗口的限制,它可能已经“看不见”句首的“张三”了。虽然信息可以通过多层堆叠逐步传递,但路径变长了,信号会衰减。
② 破坏了 Transformer 的“涌现”能力
Transformer 的强大之处在于它能通过数据自己发现规律。
- 全注意力: 模型可以学习到“第 5 个词对第 500 个词非常重要”。
- 模式化稀疏: 你直接告诉模型“第 5 个词绝对不能看第 500 个词”。这种强人工干预限制了模型在海量数据中挖掘潜在复杂关联的可能性。
③ 感受野衰减(Receptive Field)
虽然增加层数可以扩大等效感受野,但在处理超长序列(如 10 万长度)时,顶层 Token 想要通过逐层传递获取底层远端信息,效率极低。
既然它这么死板,为什么像 Longformer、BigBird 甚至最近的一些模型还在参考这种思路?
- 硬件的妥协: 它是目前最容易在 NVIDIA 显卡上实现加速的方案。通过将注意力限制在固定的小块,可以使用
Block-sparse算子,运行速度极快。 - 自然的语言特性: 在很多 NLP 任务中,80% 的关键信息确实隐藏在局部上下文里。这种“死板”的先验虽然牺牲了 20% 的特殊情况,但换取了处理 100 倍长度文本的能力。
- 折中方案: 现在的模型倾向于在浅层使用稀疏注意力节省资源,而在深层保留全注意力来捕捉全局逻辑,或者使用 FlashAttention 这种不改变逻辑只优化 IO 的技术。
基于内容的稀疏
在 Standard Transformer 中,我们要计算 \(Q\) 和 \(K\) 的乘积。如果 \(Q\)(查询)和 \(K\)(键)的向量在空间中距离很远,它们的点积就很小,注意力权重几乎为 0。
基于内容的稀疏注意力认为: 既然大多数计算结果最后都是 0,那我们为什么不先进行筛选,只对那些“可能产生大点积”的 \(Q\) 和 \(K\) 进行计算?
基于哈希的稀疏:Reformer (LSH Attention)
LSH (Locality Sensitive Hashing,局部敏感哈希) 是这种架构的核心黑科技。
实现原理:
- 投影分桶: 想象空间中有无数个点(Token 向量),我们随机画几根线(随机旋转),将空间切成若干个“扇区”(桶)。
- 哈希签名: 距离很近的点有极大概率掉进同一个桶里。
- 只算桶内: 我们只让同一个桶里的 \(Q\) 和 \(K\) 进行点积计算。
如何解决死板问题: 它没有规定看“位置”,而是看“语义”。如果第 1 个词和第 10000 个词被哈希到了同一个桶,它们依然能直接建立联系。
局限性: 哈希具有随机性,偶尔会把原本相关的两个点分到不同的桶里(即“漏掉”了重要的注意力)。
基于聚类的稀疏:Routing Transformer
这种架构比哈希更进一步,它利用了无监督学习中的经典算法——K-means 聚类。
实现原理:
- 全局聚类: 在计算注意力之前,先对所有的 \(Q\) 和 \(K\) 进行聚类,划分为 \(k\) 个簇(Cluster)。
- 路由引导: 每个 \(Q\) 只会被“路由”到属于同一个簇的 \(K\) 上。
- 动态分配: 随着模型每一层的学习,向量的表示在变,聚类的结果也在变。
如何解决死板问题: 这是一种动态稀疏。模型在每一层都在重新学习“哪些信息应该聚在一起”。
局限性: 聚类算法(如 K-means)本身计算量就不小,如果序列极长,聚类本身可能成为新的瓶颈。
为什么这种“高级”方案在当下反而没火?
虽然它们在数学哲学上比“滑动窗口”更优美,但目前的工业界大模型(如 Llama, GPT-4)却很少直接采用它们,原因很现实:
- 算力不对称:在 GPU 上,“杂乱但量小”的计算(稀疏访问)往往比“整齐但量大”的计算(稠密矩阵乘法)还要慢。哈希和聚类产生的稀疏索引是非常随机的,显存读取效率(Memory Bandwidth)极低。
- 梯度传播的稳定性:哈希和聚类本质上包含一种“硬选择”(要么在这个桶,要么不在)。这种不可导的操作会给深度网络的梯度反向传播带来麻烦,导致训练不容易收敛。
- FlashAttention 的降维打击:2022 年出现的 FlashAttention 证明了:我们不需要为了省计算量而去搞复杂的哈希。只要优化好 GPU 缓存的读写逻辑,直接暴力算全注意力也能飞快。
线性注意力
在标准注意力机制中,公式是:
- 痛点: 必须先让 \(Q\) 和 \(K^T\) 相乘,得到一个 \(N \times N\) 的巨大分数矩阵,然后再乘以 \(V\)。
- 线性注意力的思路: 如果我们能把公式变成 \(Q(K^TV)\),那么 \(K^TV\) 的计算量只跟特征维度 \(d\) 有关,与序列长度 \(N\) 呈线性关系!
但问题是:Softmax 挡在中间,导致矩阵乘法的结合律失效了。
技术实现:核函数(Kernel Trick)
为了去掉讨厌的 Softmax,线性注意力引入了核函数的概念(代表作:Performer, Linear Transformer)。本质是用核函数来计算向量间的相似度,而且符合结合律,用数学的近似来换取计算效率。
特征映射: 找一个函数 \(\phi(\cdot)\),使得 \(\text{sim}(q, k) \approx \phi(q)\phi(k)^T\)。
重写公式:
利用结合律:
计算优势: \(\phi(K)^T V\) 的结果是一个 \(d \times d\) 的小矩阵(通常 \(d=64\) 或 \(128\))。
- 然后再用 \(\phi(Q)\) 乘以这个小矩阵。
- 整个过程不需要生成 \(N \times N\) 的矩阵,显存占用极低。
优点:
- 全域视野: 它不像“滑动窗口”那样死板,每个 Token 理论上依然能感知到序列中任何位置的信息。
- 极速推理: 随着文本变长,它的计算开销增加非常缓慢(线性增长)。
- RNN 化: 它可以被改写成类似 RNN 的形式,实现“流式”处理,这意味着处理无限长的文本时,内存不会爆炸。
致命伤:
- 注意力“平滑化”: Softmax 的作用是“挑尖”,即让重要的信息非常突出,不重要的压到接近 0。线性注意力(使用核函数近似)往往会导致权重变得很平淡,模型分不清重点,精度通常不如标准 Transformer。
- 因果掩码(Causal Masking)实现复杂: 在生成任务中,模型不能看未来的词。在线性注意力下实现这个限制需要复杂的数学技巧(前缀和优化),容易导致数值不稳定。
- 训练难: 这种架构对学习率非常敏感,容易梯度消失或爆炸。
Note
稀疏注意力在实际工程中的一些缺点:
- 硬件不友好(The Efficiency Gap): GPU 最擅长处理规整的密矩阵运算。稀疏矩阵往往导致大量的随机内存访问,导致“理论计算量低,实际运行速度慢”。
- 精度损失(Information Loss): 强制切断某些 Token 之间的联系可能会导致模型丢失长程依赖。例如,在法律文档中,第 1 页的定义可能决定了第 100 页的解释,稀疏策略可能会漏掉这个连接。
- 实现复杂: 动态稀疏(如聚类)在训练过程中非常不稳定,且难以编写高效的 CUDA 内核。
现状:谁赢了?
目前在工业界(如 GPT-4, Claude 3, Gemini),稀疏注意力并没有完全统治,取而代之的是几种更均衡的方案:
- FlashAttention: 并不改变注意力的数学本质(依然是全注意力),但通过优化 GPU 读写(IO Aware)极大地提升了速度。
- GQA (Grouped Query Attention): 通过共享 Key/Value 头来减小显存占用(Llama 3 采用)。
- 混合架构: 如 Jamba 或 Griffin,结合了 Transformer 和线性循环神经网络(RNN/Mamba)。
长上下文扩展
除了 RoPE(通过角度旋转处理相对位置),面试中经常会对比 ALiBi。
ALiBi (Attention with Linear Biases)
- 做法: 不在 Embedding 里加位置信息,而是在计算 Attention Score 时,直接给 \(QK^T\) 减去一个正比于距离的惩罚项。
- 优点: 外推性极强。模型在 1k 长度训练,推理时直接喂 10k 也没问题,因为它对远距离词的感知是线性衰减的。
RoPE 的扩展 (Position Interpolation - PI)
- 痛点: 训练是 2048 长度,现在想跑 4096。
- 方案: 把 4096 个词的索引“挤一挤”,缩放到 0-2047 的范围内。这叫线性插值。
- 进化: 后来又有了 YaRN 或 NTK-aware Scaled RoPE,让高频维度和低频维度以不同比例缩放,效果更好。