@staticmethod
def _forward_rnn(cell, input, masks, initial, drop_masks):
max_time = input.size(0) # seq_len:41
output = []
hx = initial # ([32,200], [32,200]) 初始化值全为0
for time in range(max_time):
h_next, c_next = cell(input=input[time], hx=hx) # input[time]为[32,100](为batch里面一个位置上所以词), 经过一个lstmcell输出的h_n和c_n 都为(32,200)
h_next = h_next*masks[time] + initial[0]*(1-masks[time]) # masks(41,32,200),masks[time]为一个(32,200)值为0或者1,
c_next = c_next*masks[time] + initial[1]*(1-masks[time]) # 这里后面半句不应该一直都是0吗?
output.append(h_next) # 0-40每个位置上的输出
if drop_masks is not None: h_next = h_next * drop_masks
hx = (h_next, c_next) # 把上一个h_n,c_n作为参数传入下一个lstmcell
output = torch.stack(output, 0) # 把列表连接成(41,32,200)
return output, hx # 返回的结果是每一列输出h_n连接成的(41,32,200),和hx最后一个lstmcell输出的(h_next, c_next)
@staticmethod
def _forward
pytorch 重写lstm 使用mask
最新推荐文章于 2025-01-02 15:13:44 发布