BERT 等大模型性能强大,但很难部署到算力、内存有限的设备中。为此,来自华中科技大学、华为诺亚方舟实验室的研究者提出了 TinyBert,这是为基于 transformer 的模型专门设计的知识蒸馏(knowledge distillation,KD)方法。通过这种新的 KD 方法,大型 teacherBERT 模型中编码的大量知识可以很好地迁移到小型 student TinyBert模型中。模型大小还不到 BERT 的 1/7,但速度是 BERT 的 9 倍还要多,而且性能没有出现明显下降。
TinyBert 的结构如下图:
在TinyBert中,student 和 teacher 网络都是通过 Transformer 层构建的。
此外,研究者还提出了一种专门用于 TinyBERT 的两段式学习框架,从而分别在预训练和针对特定任务的学习阶段执行 transformer 蒸馏。这一框架确保 TinyBert 可以获取 teacherBERT 的通用知识和针对特定任务的知识。
除了提出新的 transformer 蒸馏法之外,研究者还提出了一种专门用于 TinyBERT 的两段式学习框架,从而分别在预训练和针对特定任务的具体学习阶段执行 transformer 蒸馏。这一框架确保 TinyBERT 可以获取 teacherBERT 的通用和针对特定任务的知识。
BERT模型的瘦身方法
1) 网络剪枝:
网络剪枝包括从模型中删除一部分不太重要的权重从而产生稀疏的权重矩阵,或者直接去掉与注意力头相对应的矩阵等方法来实现模型的剪枝,还有一些模型通过正则化方法实现剪枝。
2) 低秩分解:
即将原来大的权重矩阵分解多个低秩的小矩阵从而减少了运算量。这种方法既可以用于词向量以节省磁盘内存,也可以用到前馈层或自注意力层的参数矩阵中以加快模型训练速度。
3) 知识蒸馏
通过引入教师网络用以诱导学生网络的训练,实现知识迁移。教师网络拥有复杂的结构用以训练出推理性能优越的概率分布,是把概率分布这部分精华从复杂结构中“蒸馏”出来,再用其指导精简的学生网络的训练,从而实现模型压缩,即所谓知识蒸馏。蒸馏简单的说是将大模型(teacher)的学习结果,作为小模型(student)的学习目标,这样将大模型学到的知识迁移到另一个轻量级单模型上。
teacher和student模型原理甚至可以毫不相关,它的work原理,一方面student模型的loss构造学习了本身的true-label,也学到了teacher model的soft label, soft label本身也相对精确的模型(teacher)是数据泛化的一种结果,例如在二分类中,true label是【伤痛欲绝】,teacher大模型学到的是【一点点忧郁】,那【一点点忧郁】作为soft-label 也是student的学习目标,对于student只用泛化到不是【开心】就足够了。
4) 参数共享
ALBERT模型是BERT模型的改进版,其改进之一就是参数共享。全连接层与自注意力层都实现参数共享,即共享了编码器中的所有参数,这样不仅减少了参数量还提升了训练速度。
5) 量化
通过减少每个参数所需的比特数来压缩原始网络,可以显著降低内存需求。
6) 预训练和Downstream
模型压缩可以在模型训练时进行也可以在模型训练好之后进行。后期压缩使得训练更快,通常不需要训练数据,而训练期间压缩可以保持更高的准确性并导致更高的压缩率。
TinyBert瘦身方法
Tinybert主要用到的方法是模型的蒸馏。
知识蒸馏使用的是老师-学生(Teacher-Student)[1]模型,其中老师模型是“知识”的输出者,学生模型是“知识”的接受者。知识蒸馏的过程分为2个阶段:
- 原始模型训练
老师模型(Net-T)的特点是模型相对复杂,也可以由多个分别训练的模型集成而成。我们对老师模型不作任何关于模型架构、参数量、是否集成方面的限制,唯一的要求就是对于输入X, 其都能输出Y。其中Y经过Softmax函数的映射,输出值对应相应类别的概率值。
- 模型蒸馏
学生模型(Net-S)的特点是参数量较小、模型结构相对简单的单模型。同样地,对于输入X,其都能输出Y,Y经过Softmax函数映射输出对应相应类别的概率值。
在知识蒸馏的论文中,作者将问题限定在分类问题下,或者其他本质上属于分类问题的问题,该类问题的共同点是模型最后会有一个Softmax函数,其输出值对应了相应类别的概率值。
回到机器学习最基础的理论,机器学习最根本的目的是训练出在某个问题上泛化能力强的模型。即在某问题的所有数据上都能很好地反应输入和输出之间的关系,无论是训练数据,还是测试数据,还是任何属于该问题的未知数据。
而现实中,由于我们不可能收集到某问题的所有数据作为训练数据,并且新数据总是在源源不断的产生,因此我们只能退而求其次,训练目标变成在已有的训练数据集上建模输入和输出之间的关系。由于训练数据集是对真实数据分布情况的采样,训练数据集上的最优解往往会多少偏离真正的最优解(这里的讨论不考虑模型容量)。
而在知识蒸馏时,由于我们已经有了一个泛化能力较强的Net-T,我们在利用Net-T来蒸馏训练Net-S时,可以直接让Net-S去学习Net-T的泛化能力。
一个很直白且高效的迁移泛化能力的方法就是:使用Softmax层输出的类别的概率来作为“soft target”(软标签)。
如下图11.18所示,传统机器学习模型在训练过程中拟合的标签为硬标签(Hard targets),即对真实类别的标签取独热编码并求极大似然,而知识蒸馏的训练过程则是使用了软标签,用大模型的的各个类别预测的概率作为软标签(Soft targets)。
图11.18 硬标签与软标签
这是由于大模型softmax层的输出,除了正例之外,负标签也带有大量的信息,比如某些负标签对应的概率远远大于其他负标签。而在传统的训练过程(hard target)中,所有负标签都被统一对待。也就是说,KD的训练方式使得每个样本给Net-S带来的信息量远远大于传统的训练方式,通过软标签的学习可以让大模型教会小模型如何去学习。
而这个构造软标签的过程,涉及到知识蒸馏一个非常经典的概念,蒸馏温度。在介绍蒸馏温度之前,我们回顾一下Softmax公式。
但如果直接使用Softmax层的输出值作为软标签, 这又会带来一个问题: 当softmax输出的概率分布熵相对较小时,负标签的值都很接近0,对损失函数的贡献非常小,小到可以忽略不计。因此"蒸馏温度"这个变量就派上了用场,如下式11.50所示。
T即蒸馏温度。当T=1时,该式即是正常的Softmax公司。随着T越变高,softmax的输出概率结果也会越趋于平滑,其分布的熵越大,负标签携带的信息会被相对地放大,模型训练将能关注到负标签的信息。其中红色柱为真实标签的类别概率,蓝色柱为负标签的类别概率。
图11.19引入蒸馏温度软标签的变化
通用的知识蒸馏框架图如图11.20所示。训练Net-T的过程即我们正常任务使用大模型完成当前的任务,下面详细讲讲第二步:高温蒸馏的过程。高温蒸馏过程的目标函数由distill loss(对应soft target)和student loss(对应hard target)加权得到,其表达式如下式11.51所示。
图11.20 知识蒸馏通用框架
为何第二部分Loss 仍引入硬标签呢?这是因为教师网络也有一定的错误率,使用真实标签的one hot编码可以有效降低错误被传播给学生网络的可能。打个比方,老师虽然学识远远超过学生,但是他仍然有出错的可能,而如果学生可以同时参考到标准答案,就可以有效地降低被老师偶尔的错误“带偏”的可能性。