人工智能如何克服遗忘困境?

479次阅读
没有评论

人工智能如何克服遗忘困境?

图片来源@视觉中国

文|追问NextQuestion

活到老,学到老,人类可以在不断变化的环境中连续自适应地学习——在新的环境中不断吸收新知识,并根据不同的环境灵活调整自己的行为。模仿碳基生命的这一特性,针对连续学习(continual learning,CL)的机器学习算法的研究应运而生,并成为大家日益关注的焦点。

那么,什么是连续学习?相较于传统单任务的机器学习方法,连续学习旨在学习一系列任务,即在连续的信息流中,从不断改变的概率分布中学习和记住多个任务,并随着时间的推移,不断学习新知识,同时保留之前学到的知识。

然而,这个领域的技术发展并非一帆风顺,面临着许多难题。《庄子·秋水》中曾描述过一个这样的故事:战国时期,燕国有一少年听闻赵国都城邯郸人走路姿势异常优美,心向往之。遗憾的是,他在跟随邯郸人学步数月后,却把之前走路姿势忘记了,最后甚至都不会走路了,无奈只好爬回了燕国。有趣的是,这则寓言故事深蕴着当前连续学习模型的困境之一——灾难性遗忘(catastrophic forgetting),模型在学习新任务之后,由于参数更新对模型的干扰,会忘记如何解决旧任务。而对于机器学习技术而言,另一普遍关注的概念便是泛化误差(generalization error),这是衡量机器学习模型泛化能力的标准,用以评估训练好的模型对未知数据预测的准确性。泛化误差越小,说明模型的泛化能力越好。

尽管目前很多实验研究致力于解决连续学习中的灾难性遗忘问题,但是对连续学习的理论研究还十分有限。哪些因素与灾难性遗忘和泛化误差相关?它们如何明确地影响模型的连续学习能力?对此我们所知甚少。

近期,来自美国俄亥俄州立大学Ness Shroff教授团队的研究工作“Theory on Forgetting and Generalization of Continual Learning”或有望为这一问题提供详细的解答。他们从理论上解释了过度参数化(over parameterization)、任务相似性(task similarity)和任务排序(task ordering)对遗忘和泛化误差的影响,发现更多的模型参数、更低的噪声水平、更大的相邻任务间差异,有助于降低遗忘。同时,通过深度神经网络(DNN),他们在真实数据集上验证了该理论的可行性。

人工智能如何克服遗忘困境?

图注:论文封面,该论文于2023年2月刊登在ArXiv上

人工智能如何克服遗忘困境?

连续学习线性模型的构建

在经典的机器学习理论中,参数越多,模型越复杂,往往会带来不期望见到的过拟合。但以DNN为代表的深度学习模型则不然,其参数越多,模型训练效果越好。为了理解这一现象,作者更加关注在过参数化的情况下(p>n),连续学习模型的表现。文章首次定义了基于过参数化线性模型的连续学习模型,考量其在灾难性遗忘和泛化误差问题上的闭合解(定理1.1)。

定理1.1  当p≥n+2时,则:

人工智能如何克服遗忘困境?

T={1,…,T}代表任务序列;||wi∗ – wj∗||2表征任务i和j之间的相似性;p为模型实际参数的数量;n为模型需要的参数数量;r为过参数化的比例,r=1-n/p;σ为噪声水平;ci,j =(1-r)(rT-i-rj-i+rT-j),其中1≤i≤j≤T更多参数介绍详看原始文献和附录部分。

(9)式和(10)式分别为灾难性遗忘FT和泛化误差GT的数学表示。它们不仅描述了连续学习在线性模型中是如何工作的,还为其在一些真实的数据集和DNN中的应用提供指导。

连续学习中的鼎足三分

在上述数学模型的基础上,作者还研究了在连续学习过程中,过参数化、任务之间的相似程度和任务的训练顺序三个因素对灾难性遗忘和泛化误差的影响。

1)过参数化

· 更多的模型训练参数将有助于降低遗忘

如定理1.1所示,当表示参数数量的p趋近于0时,E[FT]也将趋近于零。

· 噪声水平和(或)任务间相似度低的情况下,过参数化更好

为了比较过参数化和欠参数化时模型的性能,作者构建了与定理1.1类似的,在欠参数情况下的理论模型定理1.2。

定理1.2  当n≥p+2时,则:

人工智能如何克服遗忘困境?

如定理1.2所示,欠参数化的情况下,当噪声水平σ较大时,以及当训练的任务间区分度较大时,E[FT]和E[GT]都变大。相反,过参数化的情况下,当噪声水平σ较大时,以及当训练的任务间不太相似时,E[FT]和E[GT]都变小。这表明当噪声水平高和(或)训练任务相似性较低时,过参数化的情况可能比欠参数化的情况训练效果更好,即存在良性过拟合。

2)连续训练任务的相似性

· 泛化误差随着任务相似性的增加而降低,而遗忘则可能不会随之降低

如定理1.1所示,由于公式(10)中G2项的系数始终为正,所以当任务之间越相似,区分度越少时,泛化误差会相应降低。但是由于公式(9)中,F2项的系数并不总是为正,所以可能出现任务之间的相似性增加模型的遗忘性能也增加的情况。

3)任务训练顺序

· 在早期阶段将差异大的任务相邻训练,将有助于降低遗忘

为了找到连续学习中,任务的最优训练顺序。作者考虑了两种特殊情况。情况一,任务集由一个特殊的任务,和剩余其它完全一模一样的任务组成。情况二,任务集由数目相同的不同任务组成。通过对两种情况的比较分析得出:

首先,特殊的任务在训练时,应优先在前半段执行;

其次,相邻任务之间应差异较大;这些措施都将有助于降低连续学习模型的遗忘。但是,最小化的遗忘和最小化的泛化误差的最佳任务训练排序有时并不相同。

DNN对连续学习模型的验证

最后,为了验证上述推论的可靠性,作者使用DNN在真实数据集上进行实验。后续的实验结果明确地证实了,任务相似性对连续学习模型灾难性遗忘的非单调性影响。而关于任务排序影响的实验结果也与前面线性模型中的发现一致,即应在模型训练早期设置区分度较大的任务学习,并安排区分度较大任务相邻训练。

表1:使用TRGP和TRGP+两种任务策略在不同数据集中训练得到的准确性和反向迁移(用负值表示遗忘;值越大/正,表示知识反向迁移效果越好)结果

人工智能如何克服遗忘困境?

正向迁移:在学习新任务的过程中,利用以前的任务中学习到的经验来帮助新任务的知识学习。

反向迁移:在学习新任务的过程中,学习到的新知识,巩固了以前任务的知识学习。

PMNIST数据集:MNIST数据集是机器学习模型训练所使用的经典数据集,包含0-9这10个数字的手写样本,其中每个样本的输入是一个图像,标签是图像所代表的数字。PMNIST是基于MNIST数据集的变种,由10种不同的MNIST样本置换顺序的连续学习任务组成,可进行连续学习问题的评估。Split CIFAR-100数据集:CIFAR-100数据集也是机器学习模型训练所使用的经典数据集,包含100种分类任务,如蜜蜂、蝴蝶等。每类有600张彩色图像,其中500张作为训练集,100张作为测试集。同样,为了在该数据集上进行连续学习问题的评估,作者将CIFAR-100数据集等分为10组,每一组由10个完全不同的分类任务组成,重构了Split CIFAR-100连续学习数据集。

更有趣的是,作者发现,相较于赋以不同时间点学习的旧任务相同的权重(TRGP)的策略,赋以最近学习的旧任务更多的权重(TRGP+),可以更好地促进连续学习模型的知识正向迁移和反向迁移(表 1)。这些发现有望为后续连续学习策略的设计提供理论参考。

参考链接:

  • Lin, S., Ju, P., Liang, Y., & Shroff, N. (2023). Theory on Forgetting and Generalization of Continual Learning. ArXiv. /abs/2302.05836
  • 韩亚楠, & Liu, Jianwei & Luo, Xiong-Lin. (2021). 连续学习研究进展. Journal of Computer Research and Development. 10.7544/issn1000-1239.2022.20201058.

更多精彩内容,关注钛媒体微信号(ID:taimeiti),或者下载钛媒体App

Read More 

正文完
可以使用微信扫码关注公众号(ID:xzluomor)
post-qrcode
 
评论(没有评论)
Generated by Feedzy