samyyc.dev

点击登入

.. (UTC+8)

-- BPM

Fri Jul 12 2024

从代码层面理解Tacotron中CBHG模块

#112

/

samyyc

/

1114

/

3日 前更新

xs = xs.transpose(1, 2) + convs

这边进行了transpose,把维度变为原来的 (batch,Tmax,idim)(batch, Tmax, idim) ,然后进行残差连接

HighwayNet层

这一部分主要介绍HighwayNet是什么

HighwayNet最主要的功能是跳过没有用的层,加快信息传递。

根据HighwayNet的原论文,作者将传统普通神经网络的非线性变换描述成: y = H(x,WH)H(x, W_H),其中H抽象的代表非线性变换, W_H代表变换中的参数

HighwayNet添加了两个非线性变换: T(x,WT)T(x, W_T)C(x,WC)C(x, W_C) , T 被称为transform gate,C被称为carry gate,对于这两个名称的解释可以从下面这个式子体现:

y = H(x,WH)T(x,WT)+xC(x,WC)H(x, W_H) \cdot T(x, W_T) + x \cdot C(x, W_C)

可以看到,transform gate控制了这个非线性变换层原来的输出,而carry gate控制了原来的输入,换句话说,transform gate控制有多少变换应该被输出,carry gate控制有多少输入应该被直接携带到输出

在原论文中,作者把 C 设为 1-T ,原式也就变成:

y = H(x,WH)T(x,WT)+x(1T(x,WT))H(x, W_H) \cdot T(x, W_T) + x \cdot (1 - T(x, W_T))

再来看一下代码实现:

class HighwayNet(torch.nn.Module):
    def __init__(self, idim):
        super(HighwayNet, self).__init__()
        self.idim = idim
        self.H = nn.Linear(idim, idim)
        self.T = nn.Linear(idim, idim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: (batch, Tmax, idim)
        H = self.H(x)
        T = self.sigmoid(self.T(x))
        C = 1 - T
        y = H * T + x * C
        return y