PyTorch_構建一個LSTM網路單元

2020-10-29 10:00:02

今天用PyTorch參考《Python深度學習基於PyTorch》搭建了一個LSTM網路單元,在這裡做一下筆記。

1.LSTM的原理

LSTM是RNN(迴圈神經網路)的變體,全名為長短期記憶網路(Long Short Term Memory networks)。
它的精髓在於引入了細胞狀態這樣一個概念,不同於RNN只考慮最近的狀態,LSTM的細胞狀態會決定哪些狀態應該被留下來,哪些狀態應該被遺忘。
具體與RNN的區別可參考這篇博文:LSTM與RNN的比較
先放一張LSTM網路的模型圖:

在這裡插入圖片描述
如上圖所示,可以看到這是一個網路,我們單拿出其中一個單元來進行分析,可見每一個單元都包含一系列運算,那麼這些運算的意義是什麼呢?下面我們來一一解釋每個單元的具體內容。

(1)遺忘門
在這裡插入圖片描述
ht-1 :前一個時刻的Cell的輸出
xt : 當前時刻的輸入
注意:中括號的意思是將ht-1與xt拼接起來,後面出現公式同理

遺忘門主要來判斷上一狀態中的輸出對現狀態的影響大小,遺忘門的輸出要通過一個Sigmoid函數,Sigmoid函數的輸出範圍是0~1,相當於得到一個權重,後面與Ct-1相乘,以此得到上一狀態輸出對現狀態的影響。

(2)輸入門
在這裡插入圖片描述
輸入門中會得到一個臨界的細胞狀態(Ct^),表示此狀態下的備選輸出,與it作用後就得到此次狀態需要輸出的內容。

在這裡插入圖片描述
由以上兩個門就可以輸出更新後的細胞狀態Ct,輸出公式如上圖所示,需要注意這裡的「 * 」符號為哈達瑪乘積,就是對應矩陣元素相乘。

(3)輸出門
在這裡插入圖片描述
輸出門具體運算過程如上圖所示。這樣就得到了這個時刻的輸出,把這個輸出再傳入下一狀態即可。

2.程式碼實現

初始化:

import torch
import torch.nn as nn

搭建一個LSTM單元:

class LSTMCell(nn.Module):
    def __init__(self,input_size,hidden_size,cell_size,output_size):
        super(LSTMCell,self).__init__()
        self.hidden_size = hidden_size
        self.cell_size = cell_size
        #設定門輸入輸出資料的大小尺寸
        self.gate = nn.Linear(input_size+hidden_size,cell_size)
        self.output = nn.Linear(hidden_size,output_size)
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        #分類器-輸出
        self.softmax = nn.LogSoftmax(dim=1)
        
    def forward(self,input,hidden,cell):
        #拼接資料,後置的0/1 確定橫向(1)還是豎向(0)拼接 
        combined = torch.cat((input,hidden),1)
        #根據LSTM一個單元的網路圖得出三個門,並進行運算
        f_gate = self.sigmoid(self.gate(combined))
        i_gate = self.sigmoid(self.gate(combined))
        #z_state看作為Cell的中間狀態
        z_state = self.tanh(self.gate(combined))
        o_gate = self.sigmoid(self.gate(combined))
        #注意這下面的乘為哈達瑪乘積,矩陣對應元素相乘
        cell = torch.add(torch.mul(f_gate,cell),torch.mul(i_gate,z_state))
        hidden = torch.mul(self.tanh(cell),o_gate)
        output = self.output(hidden)
        output = self.softmax(output)
        return output,hidden,cell
    
    def initHidden(self):
        return torch.zeros(1,self.hidden_size)
    
    def initCell(self):
        return torch.zeros(1,self.cell_size)

範例化LSTMCell,並傳入輸入、隱含狀態等進行驗證:

lstmcell = LSTMCell(input_size=10,hidden_size=20,cell_size=20,output_size=10)
input = torch.randn(32,10)
h_0 = torch.randn(32,20)
c_0 = torch.randn(32,20)
output,hn,cn = lstmcell(input,h_0,c_0)
print(output.size(),hn.size(),cn.size())

輸出結果:
torch.Size([32, 10]) torch.Size([32, 20]) torch.Size([32, 20])

end
(以上圖片來源於網路,若侵權請聯絡刪除)