06 Training Recurrent Neural Networks - PAI-yoonsung/lstm-paper GitHub Wiki
The most common methods to train recurrent neural networks are Backpropagation Through Time (BPTT) [62, 74, 75] and Real-Time Recurrent Learning (RTRL) [75, 76], with BPTT being the more widely used of the two.
The main difference between BPTT and RTRL is the way the weight changes are calculated.
The original formulation of LSTM-RNNs used a combination of BPTT and RTRL.
Therefore we cover both learning algorithms in short.
6.1 Backpropagation Through Time
The BPTT algorithm makes use of the fact that, for a finite period of time, there is an FFNN with identical behaviour for every RNN.
To obtain this FFNN, we need to unfold the RNN in time.
Figure 9a shows a simple, fully recurrent neural network with a single two-neuron layer.
The corresponding feed-forward neural network, shown in Figure 9b, requires a separate layer for each time step with the same weights for all layers.
If the weights are identical to those of the RNN, both networks show the same behaviour.
The unfolded network can be trained using the backpropagation algorithm described in Section 4.
At the end of a training sequence, the network is unfolded in time.
The error is calculated for the output units with existing target values using some chosen error measure.
Then, the error is injected backwards into the network and the weight updates for all time steps are calculated.
The weights in the recurrent version of the network are updated with the sum of their deltas over all time steps.
Figure 9 shows a simple fully recurrent neural network with a single two-neuron layer (a) and, unfolded in time with a separate layer for each time step, the corresponding feed-forward neural network (b).
The output of a unit at time τ + 1 is given by Equation 6, and its net input, including the weighted inputs, by Equation 7,
where v ∈ U ∩ Pre(u) and i ∈ I, the set of input units.
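Equations 6 and 7 are not reproduced in this text; a reconstruction consistent with the surrounding definitions (standard notation for a unit's output and its net input) would read:

```latex
y_u(\tau+1) = f_u\!\left(\mathrm{net}_u(\tau+1)\right),
\qquad
\mathrm{net}_u(\tau+1)
  = \sum_{v \in U \cap \mathrm{Pre}(u)} W[u,v]\, y_v(\tau)
  + \sum_{i \in I \cap \mathrm{Pre}(u)} W[u,i]\, y_i(\tau+1)
```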
Note that the inputs to u at time τ + 1 are of two types: the environmental input that arrives at time τ + 1 via the input units, and the recurrent output from all non-input units in the network produced at time τ.
If the network is fully connected, then U ∩ Pre(u) is equal to the set U of non-input units.
Let T(τ) be the set of non-input units for which, at time τ, the output value y_u(τ) of the unit u ∈ T(τ) should match some target value d_u(τ).
The cost function is the summed error E_total(t', t) for the epoch t', t'+1, ..., t, which we want to minimise using a learning algorithm.
The summed error is defined by Equation 8. The error E(τ) at time step τ is defined using the squared error as the objective function (Equation 9), and the error e_u(τ) of a non-input unit u at time step τ is defined by Equation 10.
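Equations 8 to 10 are not reproduced here; with the definitions above, the usual squared-error formulation is:

```latex
E_{\mathrm{total}}(t', t) = \sum_{\tau = t'}^{t} E(\tau),
\qquad
E(\tau) = \tfrac{1}{2} \sum_{u \in T(\tau)} \bigl(e_u(\tau)\bigr)^2,
\qquad
e_u(\tau) = d_u(\tau) - y_u(\tau)
```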
To adapt the weights, the error signal ϑ_u(τ) of a non-input unit u at time step τ is defined as in Equation 11.
Unfolding ϑ_u in time, we obtain Equation 12.
After the backpropagation computation reaches time t', the weights of the recurrent version of the network are updated with ΔW[u,v], which sums the weight updates of all time steps.
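The unfold–backpropagate–sum procedure described above can be sketched in a few lines of NumPy. This is only a toy illustration, not the paper's exact formulation: it assumes a single fully recurrent tanh layer, a target value at every time step, and made-up names such as `W_rec` and `W_in`.

```python
import numpy as np

rng = np.random.default_rng(0)
n_units, n_inputs, T = 2, 3, 5

W_in = rng.normal(scale=0.5, size=(n_units, n_inputs))   # input weights
W_rec = rng.normal(scale=0.5, size=(n_units, n_units))   # recurrent weights

x = rng.normal(size=(T, n_inputs))        # input sequence
d = rng.normal(size=(T, n_units))         # target values for each step

# Forward pass: unfold the network in time, storing all activations.
y = np.zeros((T + 1, n_units))            # y[0] is the initial state
for t in range(T):
    y[t + 1] = np.tanh(W_rec @ y[t] + W_in @ x[t])

# Backward pass: inject the error and propagate it back through the
# unfolded layers, accumulating one weight delta per time step.
dW_rec = np.zeros_like(W_rec)
dW_in = np.zeros_like(W_in)
carry = np.zeros(n_units)                 # error flowing back from step t+1
for t in reversed(range(T)):
    e = (d[t] - y[t + 1]) + carry         # injected + backpropagated error
    delta = e * (1.0 - y[t + 1] ** 2)     # tanh'(net) = 1 - tanh(net)^2
    dW_rec += np.outer(delta, y[t])       # same weights, summed over steps
    dW_in += np.outer(delta, x[t])
    carry = W_rec.T @ delta

alpha = 0.1                               # learning rate
W_rec += alpha * dW_rec                   # one update with the summed deltas
W_in += alpha * dW_in
```

Note that a single weight update is performed per sequence, using the deltas summed over all unfolded layers, as the recurrent version of the network requires.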
A more detailed description of BPTT can be found in [74], [62] and [76].
6.2 Real-Time Recurrent Learning
The RTRL algorithm does not require error propagation.
All the information necessary to compute the gradient is collected as the input stream is presented to the network.
This makes a dedicated training interval obsolete.
The algorithm comes at significant computational cost per update cycle, and the stored information is non-local; i.e., we need an additional notion called the sensitivity of the output, which we'll explain later.
Nevertheless, the memory required depends only on the size of the network and not on the size of the input.
Following the notation from the previous section, we will now give definitions for the network units v ∈ I ∪ U and u, k ∈ U, and the time steps t' ≤ τ ≤ t.
Unlike BPTT, in RTRL we assume the existence of a label d_k(τ) at every time τ (given that it is an online algorithm) for every non-input unit k, so the training objective is to minimise the overall network error, which is given at time step τ by
We conclude from Equation 8 that the gradient of the total error is also the sum of the gradient for all previous time steps and the current time step:
During presentation of the time series to the network, we need to accumulate the values of the gradient at each time step. Thus, we can also keep track of the weight changes ΔW[u,v]. After presentation, the overall weight change for W[u,v] is then given by
To obtain the weight changes, we need to calculate the second equation above: for each time step t, expanding the equation according to gradient descent and applying Equation 9 yields Equation 14.
Since the error e_k(τ) = d_k(τ) − y_k(τ) is always known, we need to find a way to calculate the second factor only. We define the quantity
which measures the sensitivity of the output of unit k at time τ to a small change in the weight W[u,v], in due consideration of the effect of such a change in the weight over the entire network trajectory from time t' to t.
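Equations 14 and 15 are not reproduced in this text; consistent with the surrounding derivation, they would read:

```latex
-\frac{\partial E(\tau)}{\partial W[u,v]}
  = \sum_{k \in U} e_k(\tau)\, \frac{\partial y_k(\tau)}{\partial W[u,v]},
\qquad
p^{k}_{uv}(\tau) \equiv \frac{\partial y_k(\tau)}{\partial W[u,v]}
```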
The weight W[u,v] does not have to be connected to unit k, which makes the algorithm non-local.
Local changes in the network can have an effect anywhere in the network.
In RTRL, the gradient information is forward-propagated. Using Equations 6 and 7, the output y_k(t + 1) at time step t + 1 is given by
Applying the weighted inputs as well gives Equation 17.
Taking the derivative of Equations 15, 16 and 17, we can calculate the result for all time steps ≥ t + 1 as follows:
where δ_uk is the Kronecker delta, i.e. δ_uk = 1 if u = k and 0 otherwise.
Here, under the assumption that the initial state of the network has no functional dependence on the weights, the derivative at the first time step is:
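Equations 18 and 19 are not shown in this text; a reconstruction consistent with the description above (the sensitivity recursion and its zero initialisation) would be:

```latex
p^{k}_{uv}(t+1) = f'_k\!\left(\mathrm{net}_k(t+1)\right)
  \left[ \sum_{l \in U} W[k,l]\, p^{l}_{uv}(t) + \delta_{uk}\, y_v(t) \right],
\qquad
p^{k}_{uv}(t') = 0
```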
Equation 18 shows how p^k_uv(t + 1) can be calculated in terms of p^k_uv(t).
In this sense, the learning algorithm becomes incremental, so that we can learn as we receive new inputs (in real time), and we no longer need to perform back-propagation through time.
Knowing the initial value for p^k_uv at time t' from Equation 19, we can recursively calculate the quantities p^k_uv for the first and all subsequent time steps using Equation 18.
Note that p^k_uv(τ) uses the values of W[u,v] at t', and not the values in-between t' and τ.
Combining these values with the error vector e(τ) for that time step, using Equation 14, we can finally calculate the negative error gradient ∇_W E(τ).
The final weight change for W[u,v] can be calculated using Equations 14 and 13.
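Putting the steps above together, a whole RTRL pass can be sketched in NumPy. This is a hedged toy implementation under assumed conventions (a single fully connected tanh layer, weight columns ordered as non-input units followed by inputs, a learning rate `alpha`); the array `p[k, u, v]` plays the role of the sensitivities p^k_uv.

```python
import numpy as np

rng = np.random.default_rng(1)
n, m, T = 2, 3, 5                            # non-input units, inputs, steps
W = rng.normal(scale=0.5, size=(n, n + m))   # W[u, v]: v ranges over U then I

x = rng.normal(size=(T, m))                  # input stream
d = rng.normal(size=(T, n))                  # a label at every time step

y = np.zeros(n)                              # unit outputs
p = np.zeros((n, n, n + m))                  # p[k, u, v]: zero at t' (Eq. 19)
dW = np.zeros_like(W)                        # accumulated gradient (Eq. 13)
alpha = 0.1

for t in range(T):
    z = np.concatenate([y, x[t]])            # recurrent outputs + fresh input
    net = W @ z
    y_new = np.tanh(net)
    fprime = 1.0 - y_new ** 2                # tanh'

    # Sensitivity recursion (Eq. 18): forward-propagate gradient information.
    p_new = np.einsum('kl,luv->kuv', W[:, :n], p)  # sum_l W[k,l] p^l_uv(t)
    for u in range(n):
        p_new[u, u, :] += z                  # Kronecker delta term d_uk y_v(t)
    p_new *= fprime[:, None, None]

    # Combine with the error vector (Eq. 14) and accumulate over time.
    e = d[t] - y_new
    dW += np.einsum('k,kuv->uv', e, p_new)

    y, p = y_new, p_new

W += alpha * dW                              # final weight change for W[u,v]
```

The memory cost is visible here: `p` has one entry per (unit, weight) pair, which depends only on the network size, not on the length of the input stream.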
A more detailed description of the RTRL algorithm is given in [75] and [76].