Question

I'm having a hard time trying to derive the maths behind LSTMs and vanishing gradients.

I had a lot of help from LSTM forward and backward pass, but I got stuck on page 11 of that document.

Given the image:

[image not available]

We can form a system of equations, $$ \begin{bmatrix} a^t \\ i^t \\ f^t \\ o^t \\ \end{bmatrix} = \begin{bmatrix} tanh(W_cx^t+U_ch^{t-1}) \\ \sigma(W_ix^t+U_ih^{t-1}) \\ \sigma(W_fx^t+U_fh^{t-1}) \\ \sigma(W_ox^t+U_oh^{t-1}) \\ \end{bmatrix} = \begin{bmatrix} tanh(\hat a^t) \\ \sigma (\hat i^t) \\ \sigma (\hat f^t) \\ \sigma (\hat o^t) \\ \end{bmatrix} $$ We can then stack the pre-activations into a single vector $z$:

$$ z= \begin{bmatrix} \hat a^t \\ \hat i^t \\ \hat f^t \\ \hat o^t \\ \end{bmatrix} = \begin{bmatrix} W_c & U_c \\ W_i & U_i \\ W_f & U_f \\ W_o & U_o \\ \end{bmatrix} * \begin{bmatrix} x^t \\ h^{t-1} \\ \end{bmatrix} $$

We can find out the backprop derivation for $z$ from page 10 from LSTM forward and backward pass

$$ \delta z= \begin{bmatrix} \delta \hat a^t \\ \delta \hat i^t \\ \delta \hat f^t \\ \delta \hat o^t \\ \end{bmatrix} = \begin{bmatrix} \delta a^t \odot (1-tanh^2(\hat a^t)) \\ \delta i^t \odot i^t \odot (1-i^t) \\ \delta f^t \odot f^t \odot (1-f^t) \\ \delta o^t \odot o^t \odot (1-o^t) \\ \end{bmatrix} $$

However the next part at page 11 from LSTM forward and backward pass is where I'm confused.

Given $\delta z$, we need to find $\delta W$, $\delta h^{t-1}$,

1) The author wrote down $\delta I^t = W^T * \delta z$:

If we do some linear algebra variables moves:

$$z = W^T * I^t$$

Multiply both sides with $I^{t^{-1}}$

$$I^{t^{-1}} z = W^T$$

Multiply both sides with $z^{-1}$

$$I^{t^{-1}} = z^{-1} W^T$$

Somehow this doesn't match with the author's formula?

2) Let's ignore 1), and try to solve for $\delta I$

$$ \delta I = \begin{bmatrix} \delta x^t \\ \delta h^{t-1} \\ \end{bmatrix} $$ $$ \delta I = \frac{dE}{dI} = \begin{bmatrix} \frac{dE}{dx^t} \\ \frac{dE}{dh^{t-1}} \\ \end{bmatrix} $$

But $\frac{dE}{dh^{t-1}}$ depends on several of the equations in $z$.

Do I solve for them individually and add them up?

Note: $\frac{dE}{d\hat i_t}$ can be found at page 10 from LSTM forward and backward pass

$$h_{t-1}^{i_t}=\frac{dE}{dh_{t-1}^{i_t}}=\frac{dE}{d\hat i_t}\frac{d\hat i_t}{dh_{t-1}}=\frac{dE}{d\hat i_t}\frac{d}{dh_{t-1}}i_t(1-i_t)$$

Replace $i_t$ with $\sigma(W_ix^t+U_ih^{t-1})$, replace $\frac{dE}{d\hat i_t}$ with $\delta \hat i_t \delta i_t$

$$=\delta \hat i_t \delta i_t \frac{d}{dh_{t-1}} \sigma(W_ix^t+U_ih^{t-1}) (1-\sigma(W_ix^t+U_ih^{t-1}))$$

It looks solvable, but I would get 4 equations like the one above. Do I add them all together in the end to get $\delta h^{t-1}$? For example:

$$ \delta I = \frac{dE}{dI} = \begin{bmatrix} \frac{dE}{dx^t} \\ \frac{dE}{dh^{t-1}} \\ \end{bmatrix} = \begin{bmatrix} \text{ignore} \\ h_{t-1}^{i_t}+h_{t-1}^{a_t}+h_{t-1}^{f_t}+h_{t-1}^{o_t} \\ \end{bmatrix} $$

My logic is that, to find the total error contributed by $h_{t-1}$, you need to add them together. Is that right?

3) Finding $\delta W$ looks like an even bigger task, and I'm not sure where to start.

4) How does this relate to the error carousel? After deriving the gradients for all the weights and $h_{t-1}$, I'm not sure how this leads to avoiding vanishing gradients. I read somewhere that the weights are a constant 1, or something along those lines?

I know this is kinda long, but feel free to ask for clarification if my question does not make sense. I've been trying to solve this for almost half a month now.

Appreciate any sort of guidance. Thanks.


Solution

Part 1: Derivation

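It's important to note that $\delta z \neq z^{-1}$. Here $\delta z$ denotes the gradient of the error with respect to $z$, that is $\frac{\partial E}{\partial z}$, not an inverse.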

So on page 11, when the author says he plans to derive $\delta I^t$, $\delta h^{t-1}$, and $\delta W^t$, he's invoking the chain rule rather than multiplying by inverse matrices.

Let's derive $\delta I^t$.

  1. $\delta I^t = \frac{\partial E}{\partial I^t}$ by definition

  2. $\delta I^t = \frac{\partial z^t}{\partial I^t} * \frac{\partial E}{\partial z^t} $ by the chain rule

    2.1.1 Recall that $z^t = W * I^t$

    2.1.2 Therefore $\frac{\partial z^t}{\partial I^t} = W^T$

    2.2.1 $\frac{\partial E}{\partial z^t} = \delta z^t$ by definition

  3. Therefore $\delta I^t = W^T * \delta z^t$
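The result in step 3 can be checked numerically. The following is a minimal sketch in plain Python, assuming a toy quadratic error $E = \frac{1}{2}\lVert z - \text{target}\rVert^2$ (not the LSTM's actual loss) and small made-up sizes. The helper names `matvec`, `transpose`, and `loss` are illustrative and do not come from the referenced document.

```python
import random

def matvec(M, v):
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def transpose(M):
    """Transpose of a matrix stored as a list of rows."""
    return [list(col) for col in zip(*M)]

def loss(W, I, target):
    """Toy error E = 0.5 * ||W I - target||^2 (an assumption for this check)."""
    z = matvec(W, I)
    return 0.5 * sum((zj - tj) ** 2 for zj, tj in zip(z, target))

random.seed(0)
rows, cols = 4, 3  # toy sizes: z has 4 entries, I has 3 entries
W = [[random.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]
I = [random.uniform(-1, 1) for _ in range(cols)]
target = [random.uniform(-1, 1) for _ in range(rows)]

# Analytic gradient: delta_z = dE/dz = z - target, then delta_I = W^T * delta_z.
z = matvec(W, I)
delta_z = [zj - tj for zj, tj in zip(z, target)]
delta_I = matvec(transpose(W), delta_z)

# Numerical gradient of E with respect to each entry of I (central differences).
eps = 1e-6
numeric = []
for k in range(cols):
    plus = list(I)
    minus = list(I)
    plus[k] += eps
    minus[k] -= eps
    numeric.append((loss(W, plus, target) - loss(W, minus, target)) / (2 * eps))

max_err = max(abs(a - n) for a, n in zip(delta_I, numeric))
print("largest difference between analytic and numeric gradient:", max_err)
```

If the chain-rule result holds, the printed difference should be tiny (limited only by floating-point rounding). Note that no matrix inverse appears anywhere in the computation.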

The author continues by noting that since $I^t = \begin{bmatrix}x^t\\h^{t-1}\end{bmatrix}$, $\delta h^{t-1}$ can be read directly off $\delta I^t$: if $x^t$ is $n \times 1$ and $h^{t-1}$ is $d \times 1$, then $\delta h^{t-1}$ consists of the last $d$ entries of $\delta I^t$, after the first $n$ entries (which form $\delta x^t$).
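Writing out the blocks of $W^T$ makes the link to part 2) of the question explicit. As a sketch, using the per-gate matrices $W_c, U_c, \dots$ from the question's first equation:

$$ \delta I^t = \begin{bmatrix} \delta x^t \\ \delta h^{t-1} \end{bmatrix} = \begin{bmatrix} W_c^T & W_i^T & W_f^T & W_o^T \\ U_c^T & U_i^T & U_f^T & U_o^T \end{bmatrix} \begin{bmatrix} \delta \hat a^t \\ \delta \hat i^t \\ \delta \hat f^t \\ \delta \hat o^t \end{bmatrix} $$

$$ \delta h^{t-1} = U_c^T \delta \hat a^t + U_i^T \delta \hat i^t + U_f^T \delta \hat f^t + U_o^T \delta \hat o^t $$

So yes, the four per-gate contributions are added together; the single product $W^T * \delta z^t$ performs that sum for you. This covers only the part of the gradient that reaches $h^{t-1}$ through the gate pre-activations at step $t$.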

In the same vein, to derive $\delta W^t$:

  1. $\delta W^t = \frac{\partial E}{\partial W^t}$ by definition

  2. $\delta W^t = \frac{\partial E}{\partial z^t} * \frac{\partial z^t}{\partial W^t} $ by the chain rule

    2.1.1 Recall that $z^t = W * I^t$

    2.1.2 Therefore $\frac{\partial z^t}{\partial W^t} = (I^t)^T$

    2.2.1 $\frac{\partial E}{\partial z^t} = \delta z^t$ by definition

  3. Therefore $\delta W^t = \delta z^t * (I^t)^T$
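Step 2.1.2 is shorthand. A sketch of the same result element by element, with $j$ indexing the rows of $W$ and $k$ indexing the entries of $I^t$ (both indices are introduced here only for illustration):

$$ z^t_j = \sum_k W_{jk} I^t_k \quad\Rightarrow\quad \frac{\partial E}{\partial W_{jk}} = \frac{\partial E}{\partial z^t_j} \frac{\partial z^t_j}{\partial W_{jk}} = \delta z^t_j \, I^t_k \quad\Rightarrow\quad \delta W^t = \delta z^t * (I^t)^T $$

The order matters: $\delta z^t$ is a column vector and $(I^t)^T$ is a row vector, so their product is an outer product with the same shape as $W$. Because the same $W$ is reused at every time step, the per-step gradients $\delta W^t$ are typically summed over $t$ before the weights are updated.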

Part 2: Constant Error Carousels

The following explanation draws heavily from the 14th slide of this lecture, and from the LSTM section of this blog, which I would highly recommend referencing. Consider this portion of my answer the TL;DR version of these two links.

The Constant Error Carousel (CEC) as you might well know is the magic of the LSTM in that it prevents vanishing gradients. It's denoted as follows:

$$c_{t+1} = c_t \odot f_{t+1} + a_{t+1} \odot i_{t+1}$$

where $f$ is the forget gate, $a$ is the new input, and $i$ is the input gate.

In the case of regular RNNs, during backpropagation the derivative of an activation function such as the sigmoid is less than one (for the sigmoid it is at most 0.25). Backpropagating through many time steps therefore multiplies the gradient by $f'(x) * W$ again and again, which, unless the weights are large enough to compensate, leads to a vanishing gradient.

In the case of an LSTM, we only multiply the cell state by a forget gate, which acts as both the weights and the activation function for the cell state. As long as that forget gate stays close to one, the information from the previous cell state (and, in the backward pass, its gradient) passes through almost unchanged.
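To get a feel for the difference, here is a minimal scalar sketch in Python. It assumes a hypothetical recurrent weight of 1, a sigmoid evaluated where its derivative is largest, and a forget gate held constant at 0.99; it also follows only the direct path from one cell state to the next and ignores the paths through the gates. All values are illustrative.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def sigmoid_prime(x):
    s = sigmoid(x)
    return s * (1.0 - s)

T = 50          # number of time steps to backpropagate through (illustrative)
w = 1.0         # hypothetical scalar recurrent weight
pre_act = 0.0   # pre-activation where sigmoid' takes its maximum value, 0.25
forget = 0.99   # hypothetical forget-gate value, held constant over time

# Plain RNN: every step back in time multiplies the gradient by f'(x) * w.
rnn_factor = (sigmoid_prime(pre_act) * w) ** T

# LSTM cell state: along the direct path, every step multiplies by the forget gate.
lstm_factor = forget ** T

print("RNN gradient factor after", T, "steps:", rnn_factor)
print("LSTM cell-state gradient factor after", T, "steps:", lstm_factor)
```

Even in the best case for the sigmoid, the RNN factor collapses to a vanishingly small number, while the cell-state factor stays at a usable magnitude. With the forget gate exactly equal to one the factor would be exactly one, which is the "constant" in constant error carousel.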

So in our case, the hidden states $h^t$ and $h^{t-1}$ enter the gates, including the forget gate, as part of the input vector $I^t$, and affect the cell state accordingly. The $\delta W$ we solved for is the gradient we use to update all of the gate weights.

Licensed under: CC-BY-SA with attribution
Not affiliated with datascience.stackexchange