Transformer之后:新架构探索
AI 导读
Transformer之后:新架构探索 Mamba、RWKV、Hyena、xLSTM:突破二次复杂度瓶颈的下一代序列模型 引言...
Transformer之后:新架构探索
Mamba、RWKV、Hyena、xLSTM:突破二次复杂度瓶颈的下一代序列模型
引言
Transformer自2017年问世以来统治了几乎所有序列建模任务。然而其核心机制——自注意力(Self-Attention)——的O(n^2)时间和空间复杂度成为处理超长序列的根本瓶颈。在128K甚至1M上下文长度的需求驱动下,一系列新架构正在挑战Transformer的统治地位。它们的共同目标是:在保持建模质量的前提下,实现O(n)或近似O(n)的复杂度。
Transformer的瓶颈
复杂度分析
Self-Attention复杂度
输入序列长度: n
隐藏维度: d
计算 Q, K, V: O(n × d²) -- 线性于n
注意力分数 QK^T: O(n² × d) -- 二次于n ← 瓶颈
Softmax + 加权: O(n² × d) -- 二次于n ← 瓶颈
输出投影: O(n × d²) -- 线性于n
KV Cache (推理):
存储: O(n × d) per layer per head
总KV Cache: O(n × d × L × H)
70B模型, 128K上下文: ~40GB KV Cache
问题本质:
n=1K → QK^T: 1M 次乘法 (可接受)
n=32K → QK^T: 1G 次乘法 (昂贵)
n=128K → QK^T: 16G 次乘法 (极昂贵)
n=1M → QK^T: 1T 次乘法 (不可行)
Mamba(状态空间模型)
核心原理
Mamba是基于结构化状态空间模型(Structured State Space Model, S4)的改进。其核心思想是将序列建模视为一个连续时间系统的离散化,通过状态空间实现线性复杂度的序列处理。
状态空间模型(SSM)数学框架
连续时间:
h'(t) = A h(t) + B x(t) -- 状态转移
y(t) = C h(t) + D x(t) -- 输出映射
离散化(Zero-Order Hold):
A_bar = exp(Δ A)
B_bar = (Δ A)^{-1} (exp(Δ A) - I) Δ B
h_k = A_bar h_{k-1} + B_bar x_k -- 递推(推理)
y_k = C h_k + D x_k
关键创新(Mamba的Selective Mechanism):
B, C, Δ 均为输入依赖(input-dependent)
→ 模型可以选择性地记住或遗忘信息
→ 等价于"数据驱动的门控"
Mamba架构实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelectiveSSM(nn.Module):
"""Simplified Mamba selective state space model block."""
def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4,
expand: int = 2):
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.d_inner = d_model * expand
# Input projection
self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
# 1D convolution
self.conv1d = nn.Conv1d(
self.d_inner, self.d_inner, d_conv,
padding=d_conv - 1, groups=self.d_inner
)
# SSM parameters (input-dependent)
self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)
# Fixed parameter A (initialized with HiPPO)
A = torch.arange(1, d_state + 1).float()
self.A_log = nn.Parameter(torch.log(A))
self.D = nn.Parameter(torch.ones(self.d_inner))
self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: [batch, seq_len, d_model]
Returns:
output: [batch, seq_len, d_model]
"""
batch, seq_len, _ = x.shape
# Input projection -> (z, x)
xz = self.in_proj(x) # [B, L, 2*d_inner]
x_branch, z = xz.chunk(2, dim=-1) # Each [B, L, d_inner]
# 1D causal convolution
x_branch = x_branch.transpose(1, 2) # [B, d_inner, L]
x_branch = self.conv1d(x_branch)[:, :, :seq_len]
x_branch = x_branch.transpose(1, 2) # [B, L, d_inner]
x_branch = F.silu(x_branch)
# Input-dependent SSM parameters
x_dbl = self.x_proj(x_branch) # [B, L, 2*d_state+1]
B = x_dbl[..., :self.d_state] # [B, L, N]
C = x_dbl[..., self.d_state:2*self.d_state] # [B, L, N]
delta = F.softplus(x_dbl[..., -1:]) # [B, L, 1]
# Discretize (simplified)
A = -torch.exp(self.A_log) # [N]
A_bar = torch.exp(delta * A) # [B, L, N]
B_bar = delta * B # [B, L, N]
# Selective scan (sequential for clarity; real impl uses parallel scan)
h = torch.zeros(batch, self.d_inner, self.d_state, device=x.device)
outputs = []
for t in range(seq_len):
h = A_bar[:, t].unsqueeze(1) * h + B_bar[:, t].unsqueeze(1) * x_branch[:, t].unsqueeze(-1)
y_t = (h * C[:, t].unsqueeze(1)).sum(-1) # [B, d_inner]
outputs.append(y_t)
y = torch.stack(outputs, dim=1) # [B, L, d_inner]
y = y + x_branch * self.D # Skip connection
# Gate and output
y = y * F.silu(z)
return self.out_proj(y)
Mamba的优势与局限
| 维度 | Mamba | Transformer |
|---|---|---|
| 训练复杂度 | O(n) | O(n^2) |
| 推理复杂度 | O(1) per step | O(n) per step (KV cache) |
| 长序列能力 | 天然支持 | 需要位置编码扩展 |
| 并行训练 | 需要parallel scan | 天然并行 |
| In-context learning | 较弱 | 强 |
| 精确检索 | 弱(信息压缩到固定状态) | 强(可回看所有token) |
RWKV:线性Transformer
架构特点
RWKV结合了RNN的高效推理和Transformer的并行训练,其核心是将注意力机制替换为线性递推:
RWKV核心机制:Time-Mixing与Channel-Mixing
Time-Mixing (替代Self-Attention):
r_t = W_r · (mu_r ⊙ x_t + (1-mu_r) ⊙ x_{t-1})
k_t = W_k · (mu_k ⊙ x_t + (1-mu_k) ⊙ x_{t-1})
v_t = W_v · (mu_v ⊙ x_t + (1-mu_v) ⊙ x_{t-1})
wkv_t = (sum_{i=1}^{t-1} e^{-(t-1-i)w+k_i} v_i + e^{u+k_t} v_t)
/ (sum_{i=1}^{t-1} e^{-(t-1-i)w+k_i} + e^{u+k_t})
o_t = W_o · (sigmoid(r_t) ⊙ wkv_t)
Channel-Mixing (替代FFN):
r_t = W_r · (mu_r ⊙ x_t + (1-mu_r) ⊙ x_{t-1})
k_t = W_k · (mu_k ⊙ x_t + (1-mu_k) ⊙ x_{t-1})
o_t = sigmoid(r_t) ⊙ (W_v · max(k_t, 0)²)
关键特性:
训练: 可展开为并行计算 (类似Transformer)
推理: 递推形式 (O(1) per step, 无KV Cache)
Hyena:隐式长卷积
Hyena用参数化的长卷积替代注意力机制,其核心是通过可学习的卷积核实现全局信息混合:
class HyenaOperator(nn.Module):
"""Simplified Hyena operator using implicit long convolution."""
def __init__(self, d_model: int, max_len: int = 8192, order: int = 2):
super().__init__()
self.order = order
self.d_model = d_model
# Short convolution for local patterns
self.short_conv = nn.Conv1d(d_model, d_model * (order + 1),
kernel_size=3, padding=1, groups=d_model)
# Implicit parametrization of long convolution
self.filter_fn = nn.Sequential(
nn.Linear(1, 64),
nn.SiLU(),
nn.Linear(64, 64),
nn.SiLU(),
nn.Linear(64, d_model),
)
# Position encoding for filter generation
t = torch.linspace(0, 1, max_len).unsqueeze(-1)
self.register_buffer("t", t)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""x: [batch, seq_len, d_model]"""
batch, seq_len, _ = x.shape
# Generate projections via short convolution
x_conv = self.short_conv(x.transpose(1, 2)) # [B, D*(order+1), L]
splits = x_conv.chunk(self.order + 1, dim=1)
v = splits[0].transpose(1, 2) # [B, L, D]
# Generate long convolution filter
h = self.filter_fn(self.t[:seq_len]) # [L, D]
# Apply Hyena recurrence
y = v
for i in range(self.order):
x_i = splits[i + 1].transpose(1, 2) # [B, L, D]
# Element-wise gating
y = y * x_i
# Long convolution via FFT
y = self._fft_conv(y, h)
return y
def _fft_conv(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
"""Causal convolution via FFT for O(n log n) complexity."""
L = x.shape[1]
# Pad to avoid circular convolution artifacts
x_padded = F.pad(x.transpose(1, 2), (L, 0))
h_padded = F.pad(h.T, (0, L))
X = torch.fft.rfft(x_padded, dim=-1)
H = torch.fft.rfft(h_padded, dim=-1)
Y = X * H
y = torch.fft.irfft(Y, dim=-1)[..., :L]
return y.transpose(1, 2)
xLSTM:LSTM的现代化重生
关键创新
xLSTM通过两个关键改进让传统LSTM重新具备竞争力:
xLSTM架构
sLSTM (Scalar LSTM with exponential gating):
└── 指数门控 (exp gating) 替代sigmoid
└── 更大的遗忘能力范围
└── 标量记忆单元
mLSTM (Matrix LSTM with matrix memory):
└── 矩阵记忆单元 (d_model × d_model)
└── 类似于线性注意力的存储机制
└── 可并行训练
xLSTM = Residual Blocks of [sLSTM | mLSTM]
复杂度对比:
sLSTM: O(n) 时间, O(d) 空间 (per step)
mLSTM: O(n) 时间, O(d²) 空间 (per step)
Transformer: O(n²) 时间, O(n×d) 空间 (KV cache)
混合架构:取长补短
Jamba / Mamba-2 / Zamba
最新趋势是将Attention和SSM混合使用,在保持线性复杂度优势的同时弥补SSM在精确检索上的短板:
混合架构模式
模式A: 交替层(Jamba风格)
Layer 1: Mamba Block
Layer 2: Mamba Block
Layer 3: Attention Block ← 每N层插入一个
Layer 4: Mamba Block
Layer 5: Mamba Block
Layer 6: Attention Block ← 每N层插入一个
...
Attention占比: 1/N (通常N=3-6)
模式B: 并行分支
Input → [Mamba Branch] → ┐
→ [Attention Branch] → ┤→ Merge → Output
两个分支处理不同类型的依赖
模式C: 级联(Mamba-Attention-Mamba)
Input → Mamba(local) → Attention(global) → Mamba(refine) → Output
架构对比总结
| 架构 | 训练复杂度 | 推理复杂度 | 长序列 | ICL | 精确检索 | 成熟度 |
|---|---|---|---|---|---|---|
| Transformer | O(n^2) | O(n) per step | 受限 | 强 | 强 | 生产级 |
| Mamba/SSM | O(n) | O(1) per step | 天然 | 中 | 弱 | 早期生产 |
| RWKV | O(n) | O(1) per step | 天然 | 中 | 弱 | 早期生产 |
| Hyena | O(n log n) | O(n) | 好 | 中 | 中 | 研究级 |
| xLSTM | O(n) | O(1) per step | 天然 | 中 | 中 | 研究级 |
| 混合(Jamba) | O(n) ~ O(n^2) | 接近O(1) | 好 | 强 | 强 | 早期生产 |
选型建议
| 场景 | 推荐架构 | 理由 |
|---|---|---|
| 通用NLP/对话 | Transformer | 成熟度最高,生态最好 |
| 超长文档处理 | Mamba/混合 | 线性复杂度,128K+无压力 |
| 流式音频/信号 | Mamba/RWKV | O(1)推理,实时处理 |
| 端侧部署 | RWKV/Mamba | 无KV Cache,内存友好 |
| 需要精确召回 | Transformer/混合 | 注意力机制的核心优势 |
| 研究探索 | Mamba-2/xLSTM | 最新架构,潜力最大 |
结论
Transformer的二次复杂度瓶颈催生了一系列创新架构。Mamba以选择性状态空间模型实现了线性复杂度和强大的建模能力,RWKV以线性注意力变体实现了RNN的高效推理和Transformer的并行训练,Hyena以隐式长卷积提供了另一种全局信息混合方案,xLSTM则证明了经典架构通过现代化改造仍有竞争力。然而,这些架构在in-context learning和精确信息检索方面尚未完全匹敌Transformer,这也是为什么混合架构(如Jamba)正在成为最务实的方向——在关键层保留注意力机制,其余层使用线性复杂度的替代方案。
Maurice | maurice_wen@proton.me