Simple LSTM

A few weeks ago I released some code on Github to help people understand how LSTM’s work at the implementation level. The forward pass is well explained elsewhere and is straightforward to understand, but I derived the backprop equations myself and the backprop code came without any explanation whatsoever. The goal of this post is to explain the so called backpropagation through time in the context of LSTM’s.

If you feel like anything is confusing, please post a comment below or submit an issue on Github.

Note: this post assumes you understand the forward pass of an LSTM network, as this part is relatively simple. Please read this great intro paper if you are not familiar with this, as it contains a very nice intro to LSTM’s. I follow the same notation as this paper, so I recommend keeping the tutorial open in a separate browser tab for easy reference while reading this post.

Introduction

The forward pass of an LSTM node is defined as follows:

g(t) = ϕ(W_gx x(t) + W_gh h(t−1) + b_g)
i(t) = σ(W_ix x(t) + W_ih h(t−1) + b_i)
f(t) = σ(W_fx x(t) + W_fh h(t−1) + b_f)
o(t) = σ(W_ox x(t) + W_oh h(t−1) + b_o)
s(t) = g(t) ∗ i(t) + s(t−1) ∗ f(t)
h(t) = s(t) ∗ o(t)
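To make the equations concrete, here is a minimal numpy transcription of a single forward step. The dict-based parameter layout and the function name are mine for illustration only, not the actual interface of the code on Github:

```python
import numpy as np

def sigmoid(z):
    return 1. / (1. + np.exp(-z))

def lstm_forward_step(x, h_prev, s_prev, W, b):
    # W['gx'] has shape (mem_cells, x_dim), W['gh'] has shape (mem_cells, mem_cells), etc.
    g = np.tanh(W['gx'] @ x + W['gh'] @ h_prev + b['g'])  # candidate input (phi = tanh)
    i = sigmoid(W['ix'] @ x + W['ih'] @ h_prev + b['i'])  # input gate
    f = sigmoid(W['fx'] @ x + W['fh'] @ h_prev + b['f'])  # forget gate
    o = sigmoid(W['ox'] @ x + W['oh'] @ h_prev + b['o'])  # output gate
    s = g * i + s_prev * f  # internal state: s(t) = g(t) * i(t) + s(t-1) * f(t)
    h = s * o               # hidden output: h(t) = s(t) * o(t)
    return h, s
```

Note that all the products in the last two lines are elementwise, which is exactly what numpy's `*` does on vectors.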

By concatenating the  x(t) x(t) and  h(t1) h(t−1) vectors as follows:

xc(t) = [x(t), h(t−1)]

we can re-write parts of the above as follows:

g(t) = ϕ(W_g xc(t) + b_g)
i(t) = σ(W_i xc(t) + b_i)
f(t) = σ(W_f xc(t) + b_f)
o(t) = σ(W_o xc(t) + b_o).
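The concatenation is purely a bookkeeping convenience: stacking W_gx and W_gh side by side gives a single matrix W_g that acts on xc(t). A quick numpy check of this identity (the dimensions are chosen arbitrarily for illustration):

```python
import numpy as np

x_dim, mem_cells = 3, 4
rng = np.random.default_rng(42)
W_gx = rng.standard_normal((mem_cells, x_dim))
W_gh = rng.standard_normal((mem_cells, mem_cells))
x = rng.standard_normal(x_dim)
h_prev = rng.standard_normal(mem_cells)

# stack the two matrices horizontally and concatenate the two input vectors
W_g = np.hstack([W_gx, W_gh])   # shape (mem_cells, x_dim + mem_cells)
xc = np.hstack([x, h_prev])     # xc(t) = [x(t), h(t-1)]

# the two formulations compute exactly the same pre-activation
assert np.allclose(W_gx @ x + W_gh @ h_prev, W_g @ xc)
```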

Suppose we have a loss l(t) that we wish to minimize at every time step t that depends on the hidden layer h and the label y at the current time via a loss function f:

l(t) = f(h(t), y(t))

where f can be any differentiable loss function, such as the Euclidean loss:

l(t) = f(h(t), y(t)) = ‖h(t) − y(t)‖².
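For reference, this loss and its derivative with respect to h(t) — the quantity that will later feed the backward pass as dl(t)/dh(t) — might be sketched in numpy as follows (the function names are mine, not the repo's):

```python
import numpy as np

def euclidean_loss(h, y):
    # l(t) = ||h(t) - y(t)||^2
    return np.sum((h - y) ** 2)

def euclidean_loss_diff(h, y):
    # dl(t)/dh(t) = 2 * (h(t) - y(t))
    return 2. * (h - y)
```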

Our ultimate goal in this case is to use gradient descent to minimize the loss L over an entire sequence of length T:

L = ∑_{t=1}^{T} l(t).

Let’s work through the algebra of computing the loss gradient:

dL/dw

where w is a scalar parameter of the model (for example it may be an entry in the matrix W_gx). Since the loss l(t) = f(h(t), y(t)) only depends on the values of the hidden layer h(t) and the label y(t), we have by the chain rule:

dL/dw = ∑_{t=1}^{T} ∑_{i=1}^{M} (dL/dh_i(t)) (dh_i(t)/dw)

where h_i(t) is the scalar corresponding to the i’th memory cell’s hidden output and M is the total number of memory cells. Since the network propagates information forwards in time, changing h_i(t) will have no effect on the loss prior to time t, which allows us to write:

dL/dh_i(t) = ∑_{s=1}^{T} dl(s)/dh_i(t) = ∑_{s=t}^{T} dl(s)/dh_i(t)

For notational convenience we introduce the variable L(t) that represents the cumulative loss from step t onwards:

L(t) = ∑_{s=t}^{T} l(s)

such that L(1) is the loss for the entire sequence. This allows us to re-write the above equation as:

dL/dh_i(t) = ∑_{s=t}^{T} dl(s)/dh_i(t) = dL(t)/dh_i(t)

With this in mind, we can re-write our gradient calculation as:

dL/dw = ∑_{t=1}^{T} ∑_{i=1}^{M} (dL(t)/dh_i(t)) (dh_i(t)/dw)

Make sure you understand this last equation. The computation of dh_i(t)/dw follows directly from the forward propagation equations presented earlier. We now show how to compute dL(t)/dh_i(t), which is where the so called backpropagation through time comes into play.

Backpropagation through time

This variable L(t) allows us to express the following recursion:

L(t) = l(t) + L(t+1)   if t < T
L(t) = l(t)            if t = T
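In other words, L(t) is just a suffix sum of the per-step losses, computed from the last step backwards. A toy numpy sketch (the loss values are made up for illustration):

```python
import numpy as np

# per-step losses l(1), ..., l(T) for a hypothetical T = 4
l = np.array([0.5, 1.0, 0.25, 2.0])

# L(t) = l(t) + L(t+1), seeded with L(T) = l(T) and computed backwards
L = np.zeros_like(l)
L[-1] = l[-1]
for t in range(len(l) - 2, -1, -1):
    L[t] = l[t] + L[t + 1]

# L(1) is the loss for the entire sequence
assert np.isclose(L[0], l.sum())
```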

Hence, given the activation h(t) of an LSTM node at time t, we have that:

dL(t)/dh(t) = dl(t)/dh(t) + dL(t+1)/dh(t)

Now, we know where the first term on the right hand side, dl(t)/dh(t), comes from: it’s simply the elementwise derivative of the loss l(t) with respect to the activations h(t) at time t. The second term, dL(t+1)/dh(t), is where the recurrent nature of LSTM’s shows up: we need the next node’s derivative information in order to compute the current node’s derivative information. Since we will ultimately need to compute dL(t)/dh(t) for all t = 1, …, T, we start by computing

dL(T)/dh(T) = dl(T)/dh(T)

and work our way backwards through the network. Hence the term backpropagation through time. With these intuitions in place, we jump into the code.

Code

We now present the code that performs the backprop pass through a single node at time 1 ≤ t ≤ T. The code takes as input:

  • top_diff_h = dL(t)/dh(t) = dl(t)/dh(t) + dL(t+1)/dh(t)
  • top_diff_s = dL(t+1)/ds(t).

And computes:

  • self.state.bottom_diff_s = dL(t)/ds(t)
  • self.state.bottom_diff_h = dL(t)/dh(t−1)

whose values will need to be propagated backwards in time. The code also adds derivatives to:

  • self.param.wi_diff = dL/dW_i
  • self.param.bi_diff = dL/db_i

since, recall, we must sum the derivatives from each time step:

dL/dw = ∑_{t=1}^{T} ∑_{i=1}^{M} (dL(t)/dh_i(t)) (dh_i(t)/dw).

Also, note that we use:

  • dxc = dL/dxc(t)

where we recall that xc(t) = [x(t), h(t−1)]. Without further ado, the code:

def top_diff_is(self, top_diff_h, top_diff_s):
    # notice that top_diff_s is carried along the constant error carousel
    ds = self.state.o * top_diff_h + top_diff_s
    do = self.state.s * top_diff_h
    di = self.state.g * ds
    dg = self.state.i * ds
    df = self.s_prev * ds

    # diffs w.r.t. vector inside sigma / tanh function
    di_input = (1. - self.state.i) * self.state.i * di 
    df_input = (1. - self.state.f) * self.state.f * df 
    do_input = (1. - self.state.o) * self.state.o * do 
    dg_input = (1. - self.state.g ** 2) * dg

    # diffs w.r.t. inputs
    self.param.wi_diff += np.outer(di_input, self.xc)
    self.param.wf_diff += np.outer(df_input, self.xc)
    self.param.wo_diff += np.outer(do_input, self.xc)
    self.param.wg_diff += np.outer(dg_input, self.xc)
    self.param.bi_diff += di_input
    self.param.bf_diff += df_input       
    self.param.bo_diff += do_input
    self.param.bg_diff += dg_input       

    # compute bottom diff
    dxc = np.zeros_like(self.xc)
    dxc += np.dot(self.param.wi.T, di_input)
    dxc += np.dot(self.param.wf.T, df_input)
    dxc += np.dot(self.param.wo.T, do_input)
    dxc += np.dot(self.param.wg.T, dg_input)

    # save bottom diffs
    self.state.bottom_diff_s = ds * self.state.f
    self.state.bottom_diff_x = dxc[:self.param.x_dim]
    self.state.bottom_diff_h = dxc[self.param.x_dim:]
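
The function above handles a single time step. To run backpropagation through time over a whole sequence, an outer loop has to feed each node the diffs produced by the node one step ahead of it. Here is a sketch of such a driver; the node and loss-layer interfaces are assumptions modeled on the code above, not the repo’s exact training loop:

```python
import numpy as np

def backward_pass(nodes, loss_layer, y_list, mem_cell_ct):
    # nodes[t] is the unrolled LSTM node at (0-indexed) time step t; each node
    # exposes top_diff_is(top_diff_h, top_diff_s) and, after that call,
    # state.bottom_diff_h and state.bottom_diff_s, as in the code above.
    T = len(nodes)
    # at t = T there is no future node: dL(T)/dh(T) = dl(T)/dh(T) and dL(T+1)/ds(T) = 0
    diff_h = loss_layer.bottom_diff(nodes[-1].state.h, y_list[-1])
    diff_s = np.zeros(mem_cell_ct)
    nodes[-1].top_diff_is(diff_h, diff_s)
    # walk backwards in time: dL(t)/dh(t) = dl(t)/dh(t) + dL(t+1)/dh(t)
    for t in range(T - 2, -1, -1):
        diff_h = loss_layer.bottom_diff(nodes[t].state.h, y_list[t])
        diff_h += nodes[t + 1].state.bottom_diff_h
        diff_s = nodes[t + 1].state.bottom_diff_s
        nodes[t].top_diff_is(diff_h, diff_s)
```

Notice how bottom_diff_h and bottom_diff_s at step t+1 become top_diff_h’s second summand and top_diff_s at step t, exactly mirroring the recursion derived above.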

Details

The forward propagation equations show that modifying s(t) affects the loss L(t) by directly changing the values of h(t) as well as h(t+1). However, modifying s(t) affects L(t+1) only by modifying h(t+1). Therefore, by the chain rule:

dL(t)/ds_i(t) = (dL(t)/dh_i(t)) (dh_i(t)/ds_i(t)) + (dL(t)/dh_i(t+1)) (dh_i(t+1)/ds_i(t))
              = (dL(t)/dh_i(t)) (dh_i(t)/ds_i(t)) + (dL(t+1)/dh_i(t+1)) (dh_i(t+1)/ds_i(t))
              = (dL(t)/dh_i(t)) (dh_i(t)/ds_i(t)) + dL(t+1)/ds_i(t)
              = (dL(t)/dh_i(t)) (dh_i(t)/ds_i(t)) + [top_diff_s]_i

where the second equality uses the fact that l(t) does not depend on h(t+1), and the last step substitutes the definition top_diff_s = dL(t+1)/ds(t) given earlier.
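Since h(t) = s(t) ∗ o(t), the factor dh_i(t)/ds_i(t) is just o_i(t), which is why the code computes ds = self.state.o * top_diff_h + top_diff_s. A tiny numpy sanity check of that derivative (the values are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
s = rng.standard_normal(4)
o = rng.standard_normal(4)   # stands in for the output gate o(t), held fixed
h = s * o                    # h(t) = s(t) * o(t), elementwise

# perturb s_0 and difference h: since h is linear in s, dh_0/ds_0 = o_0 exactly
eps = 1e-6
s_pert = s.copy()
s_pert[0] += eps
dh0 = ((s_pert * o)[0] - h[0]) / eps
assert np.isclose(dh0, o[0])
```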