跳转至

Mean Flow

文档信息

创建时间:2025-12-19 | 更新时间:2025-12-19

本文基于Mean Flows for One-step Generative Modeling 做笔记

主要贡献

从平均速度的角度来蒸馏模型,减少了离散化误差,本质都是多步合成一步,但是本文以平均速度为训练目标,中间没有数学上的误差,且最后采样可能更加平滑。

核心公式

image-20251219114453744

其中 \(z_t\) 为 t 时刻隐状态, \(u(z_t,r,t)\) 表示从 r 时刻到 t 时刻的平均速度, \(v_{(z_t,t)}\) 表示教师给出的瞬时速度

重新排列得到平均速度的恒等式

image-20251219114922088

现在计算平均速度对时间的全导数(注意 r 是和 t 无关的参数),得到

image-20251219115006242

因为这三个变量对t求导可以解析,或者由模型直接给出,所以接下来可以通过jvp直接得出全导数

Note

公式 (7) 本质上描述了多变量函数沿特定方向的演化。根据链式法则:

\[\frac{d}{dt}u(z_t, r, t) = \nabla u \cdot \mathbf{v}\]

其中 \(\mathbf{v} = [\frac{dz_t}{dt}, \frac{dr}{dt}, \frac{dt}{dt}]^T\) 是变量随时间演化的速度向量。

  • JVP 的定义:给定函数 \(f(\mathbf{x})\) 和向量 \(\mathbf{v}\),JVP 计算的是 \(J \mathbf{v}\),即函数在 \(\mathbf{x}\) 处沿方向 \(\mathbf{v}\) 的导数。
  • 对应关系:在公式 (8) 中,已经解析地知道 \(\frac{dz_t}{dt} = v(z_t, t)\)\(\frac{dr}{dt} = 0\),以及 \(\frac{dt}{dt} = 1\)。这三个值正好构成了切向量 \(\mathbf{v} = [v, 0, 1]\)

Tip

JVP做的是前向推导,给出输入的变化v,计算输出的变化

VJP做的是反向传播,给出loss对本层输出的偏导,反向计算对本层输入的偏导

目标函数写作

image-20251219115331543

对应伪代码,其中 fn 是神经网络预测 u 的j

image-20251219142822938