1. ConcatTable: applies the same input to every member module.
Diagram:
                  +-----------+
             +----> {member1, |
+-------+    |    |           |
| input +----+----> member2,  |
+-------+    |    |           |
    or       +----> member3}  |
 {input}          +-----------+
Example:
mlp = nn.ConcatTable()
mlp:add(nn.Linear(5, 2))
mlp:add(nn.Linear(5, 3))
pred = mlp:forward(torch.randn(5))
for i, k in ipairs(pred) do print(i, k) end
2. ParallelTable: applies the i-th member module to the i-th input in the table.
Diagram:
+----------+ +-----------+
| {input1, +---------> {member1, |
| | | |
| input2, +---------> member2, |
| | | |
| input3} +---------> member3} |
+----------+ +-----------+
mlp = nn.ParallelTable()
mlp:add(nn.Linear(10, 2))
mlp:add(nn.Linear(5, 3))
x = torch.randn(10)
y = torch.rand(5)
pred = mlp:forward{x, y}
for i, k in pairs(pred) do print(i, k) end
3. MapTable: applies a single member module to every input in the table, cloning the module as needed so there is one copy per input. All clones share their parameters (weight, bias, gradWeight and gradBias).
Diagram:
+----------+ +-----------+
| {input1, +---------> {member, |
| | | |
| input2, +---------> clone, |
| | | |
| input3} +---------> clone} |
+----------+ +-----------+
map = nn.MapTable()
map:add(nn.Linear(10, 3))
x1 = torch.rand(10)
x2 = torch.rand(10)
y = map:forward{x1, x2}
for i, k in pairs(y) do print(i, k) end
4. SplitTable: splits the input tensor along the given dimension into a table of tensors; the dimension index may be negative, in which case it counts back from the last dimension.
module = nn.SplitTable(dimension, nInputDims)
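As a short sketch of the usage above (assuming the standard Torch `nn` package is installed; the variable names are illustrative):

```lua
require 'nn'

-- Split a 2x3 tensor along dimension 1 into a table of two size-3 tensors.
mlp = nn.SplitTable(1)
x = torch.randn(2, 3)
pred = mlp:forward(x)
for i, k in ipairs(pred) do print(i, k) end

-- A negative index counts from the last dimension:
-- here -1 refers to dimension 2, giving a table of three size-2 tensors.
mlp2 = nn.SplitTable(-1)
pred2 = mlp2:forward(x)
```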