转自:http://blog.youkuaiyun.com/hungryof/article/details/52027049
总说
上篇博客已经初步介绍了Module类。这里将更加仔细的介绍。并且还将介绍Container, Transfer Functions Layers和 Simple Layers模块。
Module
主要有4个函数。
1. [output]forward(input)
2. [gradInput] backward (input.gradOutput)
以及在后面自动挡训练方式中重载的两个函数
3. [output] updateOutput (input)
4. [gradInput] updateGradInput (input, gradOutput)
注意点:
1. forward函数的input必须和backward的函数的input一致!否则梯度更新会有问题。
2. Module的forward会调用updateOutput(input), 而backward会调用[gradInput] updateGradInput (input, gradOutput)和accGradParameters(input, gradOutput)
3. 高级训练方式只要重载updateOutput和updateGradInput这两个函数,内部参数会自动改变。
4. 对于第3点,需要更加深入的探讨
Container
复杂的神经网络可以用container类进行构建。
container的子类主要有三个:Sequential, Parallel, Concat 。他重新实现了Module类的方法。此外还增加了很多方法。
主要函数:
1. add(module)
2. get(index)
3. size()
return the number of contained modules.
4. remove(index)
5. insert (module, [index] ) 注意这里的index是插入后,其排到的index
用法就是,一般是用一个Sequential,然后不断add(module),而module有simple layers和卷积层。在后面进行说明。
Transfer Functions Layers
就是激活函数,在第上一篇博客已经说明了,torch中讲神经网络看成是module(或是container)的组合。你可以加入层模块,或是激活函数的层模块,最后还可以加上criterion层模块。如此一来,整个网络就构建好了。
这个没啥好说的。
里面有挺多激活函数的. SoftMax, SoftMin, SoftPlus, LogSigmoid, LogSoftMax, Sigmoid, Tanh, ReLU, PReLU, ELU, LeakyReLU等等。
这里就拿Tanh举例吧。
<code class="language-lua hljs has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;">ii=torch.linspace(-<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">3</span>,<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">3</span>) m=nn.Tanh() oo=m:forward(ii) go=torch.ones(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">100</span>) gi=m:backward(ii,go) gnuplot.plot({<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'f(x)'</span>,ii,oo,<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'+-'</span>},{<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'df/dx'</span>,ii,gi,<span class="hljs-string" style="color: rgb(0, 136, 0); box-sizing: border-box;">'+-'</span>}) gnuplot.grid(<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">true</span>)</code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li></ul>
Simple Layers
简单层有很多,一些是提供仿射变换的,一些是进行Tensor method的。
- 具有参数的modules有:
Linear
Add : 对输入增加一个偏置项
Mul
CMul
等等 - 进行基本Tensor运算的
View
Transpose
等等 - 进行数学运算的
Max, Min, Exp, Mean, Log, Abs, MM:matrix-matrix multiplication.
Normalize: normalize the input to have unit L-p norm - 其他
Identity
Dropout
下面调重要常见的几个
Linear
<code class="language-lua hljs has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;"><span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span> = nn.Linear(inputDim, outputDim, [bias = <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">true</span>])</code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li></ul>
Linear就是全连接呗。
<code class="language-lua hljs has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;"><span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span> = nn.Linear(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">10</span>,<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">5</span>) mlp = nn.Sequential() mlp:add(<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span>) <span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">print</span>(<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span>.weight) <span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">print</span>(<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span>.bias) <span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">print</span>(<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span>.gradWeight) <span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">print</span>(<span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span>.gradBias) x = torch.Tensor(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">10</span>) <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">-- 10 inputs</span> y = <span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span>:forward(x)</code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li></ul>
运行后发现,只要初始化后网络,里面就有初始权值和偏置。但是gradBias和gradGradient不是很大就是很小,显然这是垃圾数据。现在有个问题,网络权值的初始化对整个网络的训练非常重要,torch是怎样自定义初始化权值的呢?我现在还没发现!先做个标记!
Dropout
<code class="language-lua hljs has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;"><span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span> = nn.Dropout(p)</code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li></ul>
这个就是Dropout层。简单的说每一个神经元的输入将会以p的概率丢弃。这个方法是一个避免过拟合的正规化方法。可参见Improving neural networks by preventing co-adaptation of feature detectors
<code class="language-lua hljs has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;"><span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span> = nn.Dropout() > x = torch.Tensor{{<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span>, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">3</span>, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">4</span>}, {<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">5</span>, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">6</span>, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">7</span>, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">8</span>}} > <span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span>:forward(x) <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span> <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span> <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span> <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">8</span> <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">10</span> <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span> <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">14</span> <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span> [torch.DoubleTensor of dimension <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span>x4] > <span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span>:forward(x) <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span> <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span> <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">6</span> <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span> <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">10</span> <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span> <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span> <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">0</span> [torch.DoubleTensor of dimension <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span>x4]</code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li><li style="box-sizing: border-box; padding: 0px 5px;">10</li><li style="box-sizing: border-box; padding: 0px 5px;">11</li><li style="box-sizing: border-box; padding: 0px 5px;">12</li><li style="box-sizing: border-box; padding: 0px 5px;">13</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li><li style="box-sizing: border-box; padding: 0px 5px;">10</li><li style="box-sizing: border-box; padding: 0px 5px;">11</li><li style="box-sizing: border-box; padding: 0px 5px;">12</li><li style="box-sizing: border-box; padding: 0px 5px;">13</li></ul>
可以看出,上面一些值被丢弃。
View
这个主要是改变网络输出的tensor的sizes,就是reshape一下。
<code class="language-lua hljs has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;"><span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span> = nn.View(sizes)</code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li></ul>
如果是用view(-1)则可特别地用作minibatch的输入。
<code class="language-lua hljs has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;">x = torch.Tensor(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">4</span>, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">4</span>) <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">for</span> i = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">4</span> <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">do</span> <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">for</span> j = <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">4</span> <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">do</span> x[i][j] = (i-<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>)*<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">4</span>+j <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">end</span> <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">end</span> <span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">print</span>(nn.View(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span>, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">8</span>):forward(x)) <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">--[[ 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 ]]</span></code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li><li style="box-sizing: border-box; padding: 0px 5px;">10</li><li style="box-sizing: border-box; padding: 0px 5px;">11</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li><li style="box-sizing: border-box; padding: 0px 5px;">9</li><li style="box-sizing: border-box; padding: 0px 5px;">10</li><li style="box-sizing: border-box; padding: 0px 5px;">11</li></ul>
作为minibatch的输入!
<code class="language-lua hljs has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;">> input = torch.Tensor(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span>, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">3</span>) > minibatch = torch.Tensor(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">5</span>, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span>, <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">3</span>) > m = nn.View(-<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">1</span>):setNumInputDims(<span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span>) > <span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">print</span>(#m:forward(minibatch)) <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">5</span> <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">6</span> [torch.LongStorage of size <span class="hljs-number" style="color: rgb(0, 102, 102); box-sizing: border-box;">2</span>] <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">--每一行就是一个example的结果。</span></code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li><li style="box-sizing: border-box; padding: 0px 5px;">8</li></ul>
Normalize
<code class="language-lua hljs has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;"><span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span> = nn.Normalize(p, [eps])</code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li></ul>
L-p范式进行归一化,eps默认是1e-10,防止输入全为0时除0的情况。
MM
<code class="language-lua hljs has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;"><span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span> = nn.MM(transA,transB)</code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li></ul>
transA/transB 为true时,则对应的矩阵进行转置。
如果是3维矩阵,则第一维被理解成batches的num,转置只会对后面的两维进行。
<code class="language-lua hljs has-numbering" style="display: block; padding: 0px; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal; background: transparent;"><span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span> = nn.MM(<span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">true</span>, <span class="hljs-keyword" style="color: rgb(0, 0, 136); box-sizing: border-box;">false</span>) A = torch.randn(b,n,m) B = torch.randn(b,n,p) c = <span class="hljs-built_in" style="color: rgb(102, 0, 102); box-sizing: border-box;">module</span>:forward({A, B}) <span class="hljs-comment" style="color: rgb(136, 0, 0); box-sizing: border-box;">--[[ c是 b * m * p的 ]]</span></code><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li></ul><ul class="pre-numbering" style="box-sizing: border-box; position: absolute; width: 50px; top: 0px; left: 0px; margin: 0px; padding: 6px 0px 40px; border-right-width: 1px; border-right-style: solid; border-right-color: rgb(221, 221, 221); list-style: none; text-align: right; background-color: rgb(238, 238, 238);"><li style="box-sizing: border-box; padding: 0px 5px;">1</li><li style="box-sizing: border-box; padding: 0px 5px;">2</li><li style="box-sizing: border-box; padding: 0px 5px;">3</li><li style="box-sizing: border-box; padding: 0px 5px;">4</li><li style="box-sizing: border-box; padding: 0px 5px;">5</li><li style="box-sizing: border-box; padding: 0px 5px;">6</li><li style="box-sizing: border-box; padding: 0px 5px;">7</li></ul>
简单层一个常见的Module是Add,放在下篇博客讲。因为会用简单层的Add模块讲手动挡训练演示。