1. 从顺序累加到并行扫描:理解状态空间模型的计算革命
第一次接触状态空间模型时,我被它优雅的数学形式所吸引——直到尝试实现那个看似简单的递推公式。记得当时用for循环处理一个长度1024的序列,看着进度条缓慢前进的样子,就像在等一杯手冲咖啡。这种顺序计算的瓶颈,正是现代状态空间模型需要突破的关键。
传统状态更新遵循严格的时序依赖:当前状态必须等待前一个状态计算完成。就像多米诺骨牌,我们只能耐心等待前一块倒下才能推倒下一块。这种计算模式在数学上表示为xₖ = A xₖ₋₁ + B uₖ,其中每个xₖ都牢牢依赖于xₖ₋₁。当序列长度达到数千甚至数万时(比如处理长文档或基因序列),这种顺序计算就成了性能杀手。
并行扫描算法打破了这种线性枷锁。它的精妙之处在于发现:当所有输入确定时,状态更新实际上存在潜在的并行性。就像突然发现多米诺骨牌可以分组同时推倒——只要预先计算好各组之间的传递关系。这种思想最早由Blelloch提出,现在通过pscan函数在Mamba等模型中焕发新生。
2. Blelloch算法的双重舞蹈:理解并行扫描的核心机制
2.1 Up-sweep阶段:自底向上的部分和聚合
想象你在统计公司各部门的年度开支。传统做法是财务一个个部门询问记录(顺序计算),而并行扫描则像同时让所有部门先统计自己的开支,然后相邻部门两两合并结果。这个向上汇总的过程就是up-sweep。
用代码来看更直观。假设输入是长度为8的序列:
X = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
for k in range(int(math.log2(len(X)))):
X[1::2] += X[::2] # 奇数位加前一个偶数位
经过log₂N步后,我们得到了一个神奇的结果:原始数组中的某些位置已经包含了不同范围的部分和。就像搭积木时,某些关键节点已经预装了组合好的模块。
2.2 Down-sweep阶段:自上而下的前缀和传播
现在进入更精妙的阶段。up-sweep完


1444

被折叠的 条评论
为什么被折叠?



