一、前置知识 §
score / informant 定义为对数似然函数关于参数的梯度:
s ( θ ) ≡ ∂ θ ∂ log L ( θ )
其中L ( θ ) 即为似然函数,可扩写为L ( θ ∣ x ) ,其中x 为观测到的数据,x 从采样域X 中产生
在某一特定点的 s 函数指明了该点处对数似然函数的陡峭程度(steepness),或者是函数值对参数发生无穷小量变化的敏感性。
如果对数似然函数定义在连续实数值的参数空间中,那么它的函数值将在(局部)极大值与极小值点消失。这一性质通常用于极大似然估计中(maximum likelihood estimation, MLE),来寻找使得似然函数值极大的参数值。
注意L ( θ ∣ x ) 中竖线前后的字母θ ∣ x ,x 为随机变量,在这里则是一个定值,意为采样后的观测值 ,而θ 则为自变量,意为参数模型中的参数
当(假设 )θ 位于正确值时,我们可以通过θ 推导x ,也就是f ( x ∣ θ ) ,为一概率密度函数,意为当模型参数为θ 时,采样到x 的概率
从两个角度得到了对同一事实的论证,因此可写作f ( x ∣ θ ) = L ( θ ∣ x )
首先,来分析s 的数学期望,这里讨论的问题是:当参数取值为θ 时,s ∣ θ 的数学期望
从直观上分析,当参数位于真实 (最佳 )参数点时,似然函数有其极大值(考虑极大似然估计的定义),因此为一极值点,所以该点梯度为0 ,即E [ s ∣ θ ] = 0
下面进行公式分析:
首先要明确,该期望是s 函数关于什么随机变量 的期望。从上面的讨论中可以得到,该问题中唯一的随机变量是采样观测值x ,它的采样概率是f ( x ∣ θ )
我们有
f ( x ) ∂ x ∂ log f ( x ) = f ( x ) f ( x ) 1 ∂ x ∂ f ( x ) = ∂ x ∂ f ( x )
因此
E [ s ∣ θ ] = ∫ X f ( x ∣ θ ) ⋅ s ⋅ d x = ∫ X f ( x ∣ θ ) ∂ θ ∂ log L ( θ ∣ x ) d x = ∫ X f ( x ∣ θ ) ∂ θ ∂ log f ( x ∣ θ ) d x = ∫ X ∂ θ ∂ f ( x ∣ θ ) d x = ∂ x ∂ ∫ X f ( x ∣ θ ) d x = ∂ x ∂ 1 = 0 ■
因此得证:E [ s ∣ θ ] = 0
2. Fisher信息矩阵 §
Fisher信息(Fisher information),或简称为信息(information)是一种衡量信息量的指标
假设我们想要建模一个随机变量 x 的分布,用于建模的参数是 θ ,那么Fisher信息测量了x 携带的对于 θ 的信息量
所以,当我们固定 θ 值,以 x 为自变量,Fisher 信息应当指出这一 x 值可贡献给 θ 多少信息量
比如说,某一 θ 点附近的函数平面非常陡峭(有一极值峰值),那么我们不需要采样多少 x 即可做出比较好的估计,也就是采样点 x 的Fisher 信息量较高。反之,若某一 θ 附近的函数平面连续且平缓,那么我们需要采样很多点才能做出比较好的估计,也就是 Fisher 信息量较低。
从这一直观定义出发,我们可以联想到随机变量的方差,因此对于一个(假设的)真实参数 θ , s 函数的 Fisher 信息定义为 s 函数的方差
I ( θ ) = E [ ( ∂ θ ∂ log f ( x ∣ θ ) ) 2 θ ] = ∫ ( ∂ θ ∂ log f ( x ∣ θ ) ) 2 f ( x ; θ ) d x
此外,如果 log f ( x ∣ θ ) 对于 θ 二次可微,那么 Fisher 信息还可以写作
I ( θ ) = − E [ ∂ 2 θ ∂ 2 log f ( x ∣ θ ) θ ]
证明如下:
∵ 0 ∴ 0 A B = E [ s ∣ θ ] = ∂ θ ∂ E [ s ∣ θ ] = ∂ θ ∂ ∫ X f ( x ∣ θ ) ∂ θ ∂ log L ( θ ∣ x ) d x = ∫ X ∂ θ ∂ ∂ θ ∂ log L ( θ ∣ x ) f ( x ∣ θ ) d x ▹ use chain rule = ∫ X { ∂ 2 θ ∂ 2 log L ( θ ∣ x ) f ( x ∣ θ ) + ∂ θ ∂ f ( x ∣ θ ) ∂ θ ∂ log L ( θ ∣ x ) } d x = A ∫ X ∂ 2 θ ∂ 2 log L ( θ ∣ x ) f ( x ∣ θ ) d x + B ∫ X ∂ θ ∂ L ( θ ∣ x ) ∂ θ ∂ log L ( θ ∣ x ) d x = E [ ∂ 2 θ ∂ 2 log L ( θ ∣ x ) θ ] = ∫ X ∂ θ ∂ L ( θ ∣ x ) ∂ θ ∂ log L ( θ ∣ x ) d x = ∫ X ∂ θ ∂ log L ( θ ∣ x ) L ( θ ∣ x ) ∂ θ ∂ log L ( θ ∣ x ) d x = ∫ X ( ∂ θ ∂ log L ( θ ∣ x ) ) 2 f ( x ∣ θ ) d x = E [ ( ∂ θ ∂ log L ( θ ∣ x ) ) 2 θ ] ∵ A + B = 0 ∴ E [ ∂ 2 θ ∂ 2 log L ( θ ∣ x ) θ ] + E [ ( ∂ θ ∂ log L ( θ ∣ x ) ) 2 θ ] = 0
二、EWC §
1. 数学推导 §
假设数据集被划分为两个任务Σ = { A , B } ,网络参数为θ
学习任务为最大化后验概率
= = arg θ max P ( θ ∣Σ ) arg θ max log P ( θ ∣Σ ) arg θ min l ( θ )
其中l ( θ ) 定义为训练 loss
考虑任务训练顺序 A ⇒ B
log P ( θ ∣Σ ) = log P ( θ ∣ A , B ) = log P ( θ , A , B ) − log P ( A , B ) = log P ( B ∣ θ , A ) + log P ( θ , A ) − log P ( B ∣ A ) − log P ( A ) = log P ( B ∣ θ ) + log P ( θ ∣ A ) + log P ( A ) − log P ( B ) − log P ( A ) ▹ A, B i.i.d = loss on B log P ( B ∣ θ ) + unknown log P ( θ ∣ A ) − constant log P ( B )
其中后验概率 log P ( θ ∣ A ) 不易得到,因此使用拉普拉斯近似进行分析
在训练任务 B 之前,网络已经在任务 A 上收敛,设网络此时的参数为 θ A ∗ ,为在任务 A 上拟合得到的参数,设函数 f ( θ ) = log P ( θ ∣ A )
对 f ( θ ) 在 θ = θ A ∗ 做泰勒展开:
f ( θ ) = f ( θ A ∗ ) + = 0 ∂ θ ∂ f ( θ ) θ A ∗ ( θ − θ A ∗ ) + 2 1 ( θ − θ A ∗ ) T ∂ 2 θ ∂ 2 f ( θ ) θ A ∗ ( θ − θ A ∗ ) + ⋯ ≈ f ( θ A ∗ ) + 2 1 ( θ − θ A ∗ ) T ∂ 2 θ ∂ 2 f ( θ ) θ A ∗ ( θ − θ A ∗ )
将 f ( θ ) = log P ( θ ∣ A ) 代入:
log P ( θ ∣ A ) P ( θ ∣ A ) where: Δ ϵ = log P ( θ A ∗ ∣ A ) + 2 1 ( θ − θ A ∗ ) T ∂ 2 θ ∂ 2 log P ( θ ∣ A ) θ A ∗ ( θ − θ A ∗ ) = log P ( θ A ∗ ∣ A ) + 2 1 ( θ − θ A ∗ ) T ⎩ ⎨ ⎧ − [ − ∂ 2 θ ∂ 2 log P ( θ ∣ A ) θ A ∗ ] − 1 ⎭ ⎬ ⎫ − 1 ( θ − θ A ∗ ) = exp Δ + 2 1 ( θ − θ A ∗ ) T ⎩ ⎨ ⎧ − [ − ∂ 2 θ ∂ 2 log P ( θ ∣ A ) θ A ∗ ] − 1 ⎭ ⎬ ⎫ − 1 ( θ − θ A ∗ ) = ϵ exp − 2 1 ( θ − θ A ∗ ) T Σ − 1 ⎩ ⎨ ⎧ [ − ∂ 2 θ ∂ 2 log P ( θ ∣ A ) θ A ∗ ] − 1 ⎭ ⎬ ⎫ − 1 ( θ − θ A ∗ ) = log P ( θ A ∗ ∣ A ) = exp Δ
观察形式可得:
P ( θ ∣ A ) ∼ N θ A ∗ , ( − ∂ 2 θ ∂ 2 log P ( θ ∣ A ) θ A ∗ ) − 1
其中协方差矩阵项正是第一部分讨论的Fisher信息矩阵,记做I A ,则有
P ( θ ∣ A ) ∼ N ( θ A ∗ , [ I A ] − 1 )
另外,EWC是以一个参数的视角出发的,因此Fisher信息矩阵只需要对角线元素,其余计算出来的结果可以置0,所以有:
P ( θ ∣ A ) log P ( θ i ∣ A ) = ( 2 π ) k ∣Σ∣ 1 exp { − 2 1 ( θ − θ A ∗ ) T Σ − 1 ( θ − θ A ∗ )} = − 2 1 ( θ i − [ θ i ] A ∗ ) 2 ∗ [ Σ − 1 ] ii = − [ I A ] ii 2 ( θ i − [ θ i ] A ∗ ) 2
所以,所有参数的EWC Loss可定义为:
l EWC = − i = 1 ∑ #Params [ I A ] ii 2 ( θ i − [ θ i ] A ∗ ) 2
将上述内容代入总优化目标:
l ( θ ) = log P ( θ ∣Σ ) = log P ( B ∣ θ ) + log P ( θ ∣ A ) − log P ( B ) ⇒ l CE ( θ ∣ B ) + l EWC ( θ ∣ A )
定义超参数 λ 进行稳定性-可塑性权衡
l ( θ ) = l CE ( θ ∣ B ) + λ ⋅ l EWC ( θ ∣ A )
因此优化目标为:
arg θ min l ( θ ) = arg θ min { l CE ( θ ∣ B ) + λ ⋅ l EWC ( θ ∣ A ) } = arg θ min { l CE ( θ ∣ B ) − 2 λ i = 1 ∑ #Params [ I A ] ii 2 ( θ i − [ θ i ] A ∗ ) 2 }
2. 如何计算 Fisher 信息矩阵 §
将训练过程进行划分:
使用数据集 A 与 l CE ( θ ∣ A ) 训练模型
保存此时的参数,即θ A ∗ ,并计算 Fisher 信息矩阵 I A
使用数据集 B 、l CE ( θ ∣ B ) 与 l EWC ( θ ∣ A ) 训练模型
多任务{ A , B , C , … } 同理
最后一个问题,如何使用 A 计算 I A
考虑定义
I ( θ ) = E [ ( ∂ θ ∂ log f ( x ∣ θ ) ) 2 θ ]
可以通过计算梯度的平方来获得每一个参数的 Fisher 信息矩阵项:
I ( θ ) = N 1 ( x , y ) i ∈ A ∑ Gradient ∂ θ ∂ l LL ( θ ∣ ( x , y ) i ) 2
具体来说,可以向模型逐个喂入样本,并计算损失函数,使用神经网络框架自动计算梯度。对于每个参数,累加所有的梯度,最后除以样本数量,即可得到对应参数的 Fisher 信息矩阵项
需要注意的是,当使用 nn.CrossEntrypyLoss
或 nn.NLLLoss
时,由于其中对于 Log-Likelihood 使用了相反数处理,使用该类损失函数得到的矩阵是真实 Fisher 信息矩阵的相反数,即I ^ ( θ ) = − I ( θ ) ,在计算 loss 时要记得将减号改成加号,即
arg θ min l ( θ ) = arg θ min { l CE ( θ ∣ B ) + 2 λ i = 1 ∑ #Params [ I A ] ^ ii 2 ( θ i − [ θ i ] A ∗ ) 2 }