type
status
date
slug
summary
tags
category
icon
password
Property
长短期记忆网络的设计比门控循环单元稍微复杂一些, 却比门控循环单元早诞生了近20年
门控记忆元
长短期记忆网络的设计灵感来自于计算机的逻辑门。 长短期记忆网络引入了记忆元(memory cell),或简称为单元(cell)。 有些文献认为记忆元是隐状态的一种特殊类型, 它们与隐状态具有相同的形状,其设计目的是用于记录附加的信息。 为了控制记忆元,需要许多门。
- 一个门用来从单元中输出条目,称为输出门(output gate)
- 一个门用来决定何时将数据读入单元,称为输入门(input gate)
- 一种机制来重置单元的内容,由遗忘门(forget gate)来管理, 这种设计的动机与门控循环单元相同, 能够通过专用机制决定什么时候记忆或忽略隐状态中的输入
输入门、忘记门和输出门
就如在门控循环单元中一样, 当前时间步的输入和前一个时间步的隐状态作为数据送入长短期记忆网络的门中。 它们由三个具有
sigmoid
激活函数的全连接层处理, 以计算输入门、遗忘门和输出门的值。 因此,这三个门的值都在 的范围内。假设有 个隐藏单元,批量大小为,输入数为。 因此,输入为, 前一时间步的隐状态为。 相应地,时间步的门被定义如下: 输入门是, 遗忘门是, 输出门是
和是权重参数, 是偏置参数
候选记忆元
候选记忆元(candidate memory cell)。 它的计算与上面描述的三个门的计算类似, 但是使用函数作为激活函数,函数的值范围为。 在时间步 处的方程:
和是权重参数, 是偏置参数
记忆元
在门控循环单元中,有一种机制来控制输入和遗忘(或跳过)。 类似地,在长短期记忆网络中,也有两个门用于这样的目的: 输入门控制采用多少来自的新数据, 而遗忘门控制保留多少过去的记忆元 的内容。 使用按元素乘法,得出:
如果遗忘门始终为且输入门始终为, 则过去的记忆元 将随时间被保存并传递到当前时间步。 引入这种设计是为了缓解梯度消失问题, 并更好地捕获序列中的长距离依赖关系。
这样就得到了计算记忆元的流程图
隐状态
最后,定义如何计算隐状态, 这就是输出门发挥作用的地方。 在长短期记忆网络中,它仅仅是记忆元的 的门控版本。 这就确保了 的值始终在区间内:
只要输出门接近,就能够有效地将所有记忆信息传递给预测部分, 而对于输出门接近,只保留记忆元内的所有信息,而不需要更新隐状态。
长短期记忆网络是典型的具有重要状态控制的隐变量自回归模型。 多年来已经提出了其许多变体,例如,多层、残差连接、不同类型的正则化。 然而,由于序列的长距离依赖性,训练长短期记忆网络和其他序列模型的成本是相当高的。 在后面会有更高级的替代模型如
transformer
。从零开始实现
简洁实现
LSTM前向与反向传播算法
从RNN到LSTM
RNN
具有如下的结构,每个序列索引位置t都有一个隐藏状态 如果略去每层都有的 ,则
RNN
的模型可以简化成如下图的形式:隐藏状态 由 和 得到,得到 后一方面用于当前层的模型损失计算,另一方面用于计算下一层的 。由于
RNN
梯度消失的问题,大牛们对于序列索引位置的隐藏结构做了改进,可以说通过一些技巧让隐藏结构复杂了起来,来避免梯度消失的问题,这样的特殊RNN就是LSTM。LSTM有很多的变种,这里以最常见的LSTM为例:LSTM模型结构剖析
在每个序列索引位置时刻向前传播的除了和RNN一样的隐藏状态 ,还多了另一个隐藏状态,如图中上面的长横线。这个隐藏状态一般称为细胞状态(Cell State),记为 :
除了细胞状态,LSTM图中还有了很多奇怪的结构,这些结构一般称之为门控结构(Gate)。LSTM在在每个序列索引位置t的门一般包括遗忘门,输入门和输出门三种。
LSTM之遗忘门
遗忘门是控制是否遗忘的,在LSTM中即以一定的概率控制是否遗忘上一层的隐藏细胞状态:
输入的有上一序列的隐藏状态 和本序列数据 ,通过一个激活函数,一般是
sigmoid
,得到遗忘门的输出 。由于sigmoid
的输出 在[0,1]之间,因此这里的输出代表了遗忘上一层隐藏细胞状态的概率。数学表达式为: 其中为线性关系的系数和偏倚,和RNN中的类似。 为
sigmoid
激活函数。LSTM之输入门
输入门负责处理当前序列位置的输入:
输入门由两部分组成,第一部分使用了
sigmoid
激活函数,输出为 ,第二部分使用了tanh
激活函数,输出为 , 两者的结果后面会相乘再去更新细胞状态。数学表达式为:其中 为线性关系的系数和偏倚,和RNN中的类似。 为
sigmoid
激活函数。LSTM之细胞状态更新
前面的遗忘门和输入门的结果都会作用于细胞状态:
细胞状态 由两部分组成,第一部分是 和遗忘门输出 的乘积,第二部分是输入门的 和 的乘积,即:
为Hadamard积
LSTM之输出门
有了新的隐藏细胞状态 ,就可以来看输出门了:
隐藏状态的更新由两部分组成,第一部分是 , 它由上一序列的隐藏状态和本序列数据 ,以及激活函数
sigmoid
得到,第二部分由隐藏状态和tanh
激活函数组成:当然,有些LSTM的结构和上面的LSTM图稍有不同,但是原理是完全一样的。
LSTM前向传播算法
LSTM模型有两个隐藏状态 ,模型参数几乎是RNN的4倍,多了这些参数。
前向传播过程在每个序列索引位置的过程为:
- 更新遗忘门输出:
- 更新输入门两部分输出:
- 更新细胞状态:
- 更新输出门输出:
- 更新当前序列索引预测输出:
LSTM反向传播算法推导关键点
思路和RNN的反向传播算法思路一致,也是通过梯度下降法迭代更新所有的参数,关键点在于计算所有参数基于损失函数的偏导数。
在RNN中,为了反向传播误差,通过隐藏状态 的梯度 一步步向前传播。在LSTM这里也类似。只不过这里有两个隐藏状态 和 。这里定义两个 ,即:
为了便于推导,将损失函数 分成两块,一块是时刻 位置的损失 ,另一块是时刻之后损失:
而在最后的序列索引位置 的 $和 为:
接着由 反向推导
的梯度由本层 时刻的输出梯度误差和大于t时刻的误差两部分决定:
整个LSTM反向传播的难点就在于这部分的计算。仔细观察,由于 , 在第一项 中,包含一个 的递推关系,第二项 就复杂了, 函数里面又可以表示成:
函数的第一项中, 包含一个 的递推关系,在 函数的第二项中, 和 都包含$h$的递推关系,因此,最终 这部分的计算结果由四部分组成。即:
而 的反向梯度误差由前一层 的梯度误差和本层的从 传回来的梯度误差两部分组成:
有了 和 , 计算这一大堆参数的梯度就很容易了,这里只给出 的梯度计算过程,其他的 的梯度只要照搬就可以了。