Transformer之后:新架构探索

Mamba、RWKV、Hyena、xLSTM:突破二次复杂度瓶颈的下一代序列模型

引言

Transformer自2017年问世以来统治了几乎所有序列建模任务。然而其核心机制——自注意力(Self-Attention)——的O(n^2)时间和空间复杂度成为处理超长序列的根本瓶颈。在128K甚至1M上下文长度的需求驱动下,一系列新架构正在挑战Transformer的统治地位。它们的共同目标是:在保持建模质量的前提下,实现O(n)或近似O(n)的复杂度。

Transformer的瓶颈

复杂度分析

Self-Attention复杂度

输入序列长度: n
隐藏维度: d

计算 Q, K, V:     O(n × d²)     -- 线性于n
注意力分数 QK^T:   O(n² × d)     -- 二次于n ← 瓶颈
Softmax + 加权:    O(n² × d)     -- 二次于n ← 瓶颈
输出投影:          O(n × d²)     -- 线性于n

KV Cache (推理):
  存储: O(n × d) per layer per head
  总KV Cache: O(n × d × L × H)
  70B模型, 128K上下文: ~40GB KV Cache

问题本质:
  n=1K   → QK^T: 1M 次乘法 (可接受)
  n=32K  → QK^T: 1G 次乘法 (昂贵)
  n=128K → QK^T: 16G 次乘法 (极昂贵)
  n=1M   → QK^T: 1T 次乘法 (不可行)

Mamba(状态空间模型)

核心原理

Mamba是基于结构化状态空间模型(Structured State Space Model, S4)的改进。其核心思想是将序列建模视为一个连续时间系统的离散化,通过状态空间实现线性复杂度的序列处理。

状态空间模型(SSM)数学框架

连续时间:
  h'(t) = A h(t) + B x(t)     -- 状态转移
  y(t)  = C h(t) + D x(t)     -- 输出映射

离散化(Zero-Order Hold):
  A_bar = exp(Δ A)
  B_bar = (Δ A)^{-1} (exp(Δ A) - I) Δ B

  h_k = A_bar h_{k-1} + B_bar x_k    -- 递推(推理)
  y_k = C h_k + D x_k

关键创新(Mamba的Selective Mechanism):
  B, C, Δ 均为输入依赖(input-dependent)
  → 模型可以选择性地记住或遗忘信息
  → 等价于"数据驱动的门控"

Mamba架构实现

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelectiveSSM(nn.Module):
    """Simplified Mamba selective state space model block."""

    def __init__(self, d_model: int, d_state: int = 16, d_conv: int = 4,
                 expand: int = 2):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_inner = d_model * expand

        # Input projection
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)

        # 1D convolution
        self.conv1d = nn.Conv1d(
            self.d_inner, self.d_inner, d_conv,
            padding=d_conv - 1, groups=self.d_inner
        )

        # SSM parameters (input-dependent)
        self.x_proj = nn.Linear(self.d_inner, d_state * 2 + 1, bias=False)

        # Fixed parameter A (initialized with HiPPO)
        A = torch.arange(1, d_state + 1).float()
        self.A_log = nn.Parameter(torch.log(A))

        self.D = nn.Parameter(torch.ones(self.d_inner))
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [batch, seq_len, d_model]
        Returns:
            output: [batch, seq_len, d_model]
        """
        batch, seq_len, _ = x.shape

        # Input projection -> (z, x)
        xz = self.in_proj(x)  # [B, L, 2*d_inner]
        x_branch, z = xz.chunk(2, dim=-1)  # Each [B, L, d_inner]

        # 1D causal convolution
        x_branch = x_branch.transpose(1, 2)  # [B, d_inner, L]
        x_branch = self.conv1d(x_branch)[:, :, :seq_len]
        x_branch = x_branch.transpose(1, 2)  # [B, L, d_inner]
        x_branch = F.silu(x_branch)

        # Input-dependent SSM parameters
        x_dbl = self.x_proj(x_branch)  # [B, L, 2*d_state+1]
        B = x_dbl[..., :self.d_state]  # [B, L, N]
        C = x_dbl[..., self.d_state:2*self.d_state]  # [B, L, N]
        delta = F.softplus(x_dbl[..., -1:])  # [B, L, 1]

        # Discretize (simplified)
        A = -torch.exp(self.A_log)  # [N]
        A_bar = torch.exp(delta * A)  # [B, L, N]
        B_bar = delta * B  # [B, L, N]

        # Selective scan (sequential for clarity; real impl uses parallel scan)
        h = torch.zeros(batch, self.d_inner, self.d_state, device=x.device)
        outputs = []
        for t in range(seq_len):
            h = A_bar[:, t].unsqueeze(1) * h + B_bar[:, t].unsqueeze(1) * x_branch[:, t].unsqueeze(-1)
            y_t = (h * C[:, t].unsqueeze(1)).sum(-1)  # [B, d_inner]
            outputs.append(y_t)

        y = torch.stack(outputs, dim=1)  # [B, L, d_inner]
        y = y + x_branch * self.D  # Skip connection

        # Gate and output
        y = y * F.silu(z)
        return self.out_proj(y)

Mamba的优势与局限

维度 Mamba Transformer
训练复杂度 O(n) O(n^2)
推理复杂度 O(1) per step O(n) per step (KV cache)
长序列能力 天然支持 需要位置编码扩展
并行训练 需要parallel scan 天然并行
In-context learning 较弱
精确检索 弱(信息压缩到固定状态) 强(可回看所有token)

RWKV:线性Transformer

架构特点

RWKV结合了RNN的高效推理和Transformer的并行训练,其核心是将注意力机制替换为线性递推:

RWKV核心机制:Time-Mixing与Channel-Mixing

Time-Mixing (替代Self-Attention):
  r_t = W_r · (mu_r ⊙ x_t + (1-mu_r) ⊙ x_{t-1})
  k_t = W_k · (mu_k ⊙ x_t + (1-mu_k) ⊙ x_{t-1})
  v_t = W_v · (mu_v ⊙ x_t + (1-mu_v) ⊙ x_{t-1})

  wkv_t = (sum_{i=1}^{t-1} e^{-(t-1-i)w+k_i} v_i + e^{u+k_t} v_t)
         / (sum_{i=1}^{t-1} e^{-(t-1-i)w+k_i} + e^{u+k_t})

  o_t = W_o · (sigmoid(r_t) ⊙ wkv_t)

Channel-Mixing (替代FFN):
  r_t = W_r · (mu_r ⊙ x_t + (1-mu_r) ⊙ x_{t-1})
  k_t = W_k · (mu_k ⊙ x_t + (1-mu_k) ⊙ x_{t-1})
  o_t = sigmoid(r_t) ⊙ (W_v · max(k_t, 0)²)

关键特性:
  训练: 可展开为并行计算 (类似Transformer)
  推理: 递推形式 (O(1) per step, 无KV Cache)

Hyena:隐式长卷积

Hyena用参数化的长卷积替代注意力机制,其核心是通过可学习的卷积核实现全局信息混合:

class HyenaOperator(nn.Module):
    """Simplified Hyena operator using implicit long convolution."""

    def __init__(self, d_model: int, max_len: int = 8192, order: int = 2):
        super().__init__()
        self.order = order
        self.d_model = d_model

        # Short convolution for local patterns
        self.short_conv = nn.Conv1d(d_model, d_model * (order + 1),
                                     kernel_size=3, padding=1, groups=d_model)

        # Implicit parametrization of long convolution
        self.filter_fn = nn.Sequential(
            nn.Linear(1, 64),
            nn.SiLU(),
            nn.Linear(64, 64),
            nn.SiLU(),
            nn.Linear(64, d_model),
        )

        # Position encoding for filter generation
        t = torch.linspace(0, 1, max_len).unsqueeze(-1)
        self.register_buffer("t", t)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: [batch, seq_len, d_model]"""
        batch, seq_len, _ = x.shape

        # Generate projections via short convolution
        x_conv = self.short_conv(x.transpose(1, 2))  # [B, D*(order+1), L]
        splits = x_conv.chunk(self.order + 1, dim=1)
        v = splits[0].transpose(1, 2)  # [B, L, D]

        # Generate long convolution filter
        h = self.filter_fn(self.t[:seq_len])  # [L, D]

        # Apply Hyena recurrence
        y = v
        for i in range(self.order):
            x_i = splits[i + 1].transpose(1, 2)  # [B, L, D]
            # Element-wise gating
            y = y * x_i
            # Long convolution via FFT
            y = self._fft_conv(y, h)

        return y

    def _fft_conv(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        """Causal convolution via FFT for O(n log n) complexity."""
        L = x.shape[1]
        # Pad to avoid circular convolution artifacts
        x_padded = F.pad(x.transpose(1, 2), (L, 0))
        h_padded = F.pad(h.T, (0, L))

        X = torch.fft.rfft(x_padded, dim=-1)
        H = torch.fft.rfft(h_padded, dim=-1)
        Y = X * H
        y = torch.fft.irfft(Y, dim=-1)[..., :L]

        return y.transpose(1, 2)

xLSTM:LSTM的现代化重生

关键创新

xLSTM通过两个关键改进让传统LSTM重新具备竞争力:

xLSTM架构

sLSTM (Scalar LSTM with exponential gating):
  └── 指数门控 (exp gating) 替代sigmoid
  └── 更大的遗忘能力范围
  └── 标量记忆单元

mLSTM (Matrix LSTM with matrix memory):
  └── 矩阵记忆单元 (d_model × d_model)
  └── 类似于线性注意力的存储机制
  └── 可并行训练

xLSTM = Residual Blocks of [sLSTM | mLSTM]

复杂度对比:
  sLSTM: O(n) 时间, O(d) 空间 (per step)
  mLSTM: O(n) 时间, O(d²) 空间 (per step)
  Transformer: O(n²) 时间, O(n×d) 空间 (KV cache)

混合架构:取长补短

Jamba / Mamba-2 / Zamba

最新趋势是将Attention和SSM混合使用,在保持线性复杂度优势的同时弥补SSM在精确检索上的短板:

混合架构模式

模式A: 交替层(Jamba风格)
  Layer 1: Mamba Block
  Layer 2: Mamba Block
  Layer 3: Attention Block ← 每N层插入一个
  Layer 4: Mamba Block
  Layer 5: Mamba Block
  Layer 6: Attention Block ← 每N层插入一个
  ...
  Attention占比: 1/N (通常N=3-6)

模式B: 并行分支
  Input → [Mamba Branch] → ┐
       → [Attention Branch] → ┤→ Merge → Output
  两个分支处理不同类型的依赖

模式C: 级联(Mamba-Attention-Mamba)
  Input → Mamba(local) → Attention(global) → Mamba(refine) → Output

架构对比总结

架构 训练复杂度 推理复杂度 长序列 ICL 精确检索 成熟度
Transformer O(n^2) O(n) per step 受限 生产级
Mamba/SSM O(n) O(1) per step 天然 早期生产
RWKV O(n) O(1) per step 天然 早期生产
Hyena O(n log n) O(n) 研究级
xLSTM O(n) O(1) per step 天然 研究级
混合(Jamba) O(n) ~ O(n^2) 接近O(1) 早期生产

选型建议

场景 推荐架构 理由
通用NLP/对话 Transformer 成熟度最高,生态最好
超长文档处理 Mamba/混合 线性复杂度,128K+无压力
流式音频/信号 Mamba/RWKV O(1)推理,实时处理
端侧部署 RWKV/Mamba 无KV Cache,内存友好
需要精确召回 Transformer/混合 注意力机制的核心优势
研究探索 Mamba-2/xLSTM 最新架构,潜力最大

结论

Transformer的二次复杂度瓶颈催生了一系列创新架构。Mamba以选择性状态空间模型实现了线性复杂度和强大的建模能力,RWKV以线性注意力变体实现了RNN的高效推理和Transformer的并行训练,Hyena以隐式长卷积提供了另一种全局信息混合方案,xLSTM则证明了经典架构通过现代化改造仍有竞争力。然而,这些架构在in-context learning和精确信息检索方面尚未完全匹敌Transformer,这也是为什么混合架构(如Jamba)正在成为最务实的方向——在关键层保留注意力机制,其余层使用线性复杂度的替代方案。


Maurice | maurice_wen@proton.me