Fri Jul 12 2024
从代码层面理解Tacotron中CBHG模块
#112
/
samyyc
/
共 1114 字
/
3日 前更新
xs = xs.transpose(1, 2) + convs
这边进行了transpose,把维度变为原来的 ,然后进行残差连接
HighwayNet层
这一部分主要介绍HighwayNet是什么
HighwayNet最主要的功能是跳过没有用的层,加快信息传递。
根据HighwayNet的原论文,作者将传统普通神经网络的非线性变换描述成: y = ,其中H抽象的代表非线性变换, W_H代表变换中的参数
HighwayNet添加了两个非线性变换: 和 , T 被称为transform gate,C被称为carry gate,对于这两个名称的解释可以从下面这个式子体现:
y =
可以看到,transform gate控制了这个非线性变换层原来的输出,而carry gate控制了原来的输入,换句话说,transform gate控制有多少变换应该被输出,carry gate控制有多少输入应该被直接携带到输出
在原论文中,作者把 C 设为 1-T ,原式也就变成:
y =
再来看一下代码实现:
class HighwayNet(torch.nn.Module):
def __init__(self, idim):
super(HighwayNet, self).__init__()
self.idim = idim
self.H = nn.Linear(idim, idim)
self.T = nn.Linear(idim, idim)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
# x: (batch, Tmax, idim)
H = self.H(x)
T = self.sigmoid(self.T(x))
C = 1 - T
y = H * T + x * C
return y
