
Neural Turing Machines Explained

Building blocks for Memory Augmented Neural Networks used for One-Shot Learning

7 min read · Nov 28, 2022


In the previous parts, I discussed the basics of One-Shot Learning and how a Siamese Network implements it using Transfer Learning & a contrastive loss function. It’s time to get a little more complex: Neural Turing Machines.

Sounds like some missile !!

Neural Turing Machines (NTMs) themselves don’t implement One-Shot Learning, but they act as building blocks for MANNs (Memory Augmented Neural Networks), which are used for One-Shot Learning. Hence, we will discuss NTMs here and MANNs in my next post. NTMs work on the concept of adding explicit memory to a Neural Network. So whenever the Neural Network needs some information, it goes to the memory and fetches the relevant part to use.

Easy Peasy !!

Let’s try to understand this jargon in parts with a series of questions

What is a Turing Machine?

So, back in 1936, a few years before World War II, Alan Turing came up with the idea of an abstract machine that can carry out any algorithm. How?

The machine structure is simple:

  • An infinite tape (array) where each cell holds one of 3 possible values: 0, 1 or Blank.
  • A head/controller that initially sits above some cell. It can read, write or update the cell & then move left or right to a neighbouring cell.
https://www.qwizbowl.com/post/qwiz5-quizbowl-essentials-turing-machine

How does the controller decide when to read/write/update which cell, and with what values? We give a set of instructions to the controller, and depending upon which conditions match, the controller acts accordingly. So a set of instructions can be:

Set every 2nd cell to 0 and move left

Fill blank values to 1 and move right

You can follow this amazing blog for a deeper dive into Turing Machines

Though the idea looks simple, the Church-Turing thesis (widely accepted, though not something that can be formally proved) says that any problem that is computable can be computed by a Turing machine (the system we discussed above).

What is explicit memory in Neural Networks?

Explicit memory, in general, is stored information that a system can look up to assist with tasks. Your computer has it in the form of hard drives; you have it in the form of the knowledge you gain and store in your memory over years of experience.

On a similar note, Neural Networks can also build up a memory for themselves and use it whenever required. This memory is a 2D array where each row stores information about some particular feature of the training dataset. If you are aware of transformers and attention mechanisms, the Attention matrix calculated there can loosely be thought of as an explicit memory.

So, can we define Neural Turing Machines now?

Pretty much yes, it’s a combo of Turing Machines + Explicit memory. What actually happens is:

  • The Controller gets an external input at each time step. (In the original Turing machine, the controller only sees the cell under its head; here, assume the input plays the role of that cell.)
  • It reads the relevant memory sections, produces an output using the input & what it read, writes to the memory if the input brings any new information, and then moves on to the next input. The reading & writing are done through read & write heads.
Neural Turing Machine, https://arxiv.org/abs/1410.5401

All good, but what actually is a ‘Controller’ in terms of NTM?

It’s nothing but a neural network. Hence, the external input is fed to this network which reads memory, giving an external output and updating memory if required. This sounds similar to LSTMs !!

So far so good. Let’s go a little deeper !

How are we going to train this ‘Controller’? The memory is a 2D array, so if we directly specify a row_id & col_id to read from/write to, the selection is a discrete choice. That isn’t differentiable, and hence training with gradient descent is not possible. How do we make this row_id & col_id selection differentiable?

Weighted memory reading/writing

What we do is read from/write to all the rows at once, assigning a weight to each, instead of picking one specific row. Now that we are using weights, training is possible as the gradient of the entire process can be taken. The weights assigned to the rows are called an Attention vector, as they give an idea of which rows are more important than others. This vector is normalized, i.e. its entries are non-negative and sum to 1.

Note: we will be using a subscript ‘t’ in the formulas below to represent the value of a variable at a particular time step, as values get updated from step to step. Hence ‘t’ means the current time step, and ‘t-1’ means the previous time step.

The real, complex question is how this weight/attention vector is generated.

This involves 4 steps:

  • Content-based addressing

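Depending on the similarity between each memory row and a key vector derived from the external input, the weights of the Attention vector are assigned to each memory row using the below formula

wᶜₜ(i) = exp(βₜ × K[kₜ, Mₜ(i)]) / Σⱼ exp(βₜ × K[kₜ, Mₜ(j)])

Here,

wᶜₜ(i) = content-based weight for the iᵗʰ memory row

βₜ = amplification factor (key strength): a positive scalar produced by the controller. A larger βₜ makes the focus sharper.

K[kₜ, Mₜ(j)] = Cosine similarity between kₜ (the key vector the controller produces from the input at time step t) and the jᵗʰ Memory row (at time step t). K represents the cosine similarity function.

Where cosine similarity is

K[u, v] = (u · v) / (‖u‖ × ‖v‖)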
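Content-based addressing is just a softmax over scaled cosine similarities, which a few lines of NumPy can sketch. The function names, the small `eps` guard against division by zero, and the toy 3×2 memory are my own assumptions for illustration; they are not taken from the paper.

```python
import numpy as np

def cosine_similarity(key, memory, eps=1e-8):
    """Cosine similarity between `key` (shape M) and every row of `memory` (shape N x M)."""
    dots = memory @ key
    norms = np.linalg.norm(memory, axis=1) * np.linalg.norm(key)
    return dots / (norms + eps)  # eps avoids division by zero for all-zero rows

def content_addressing(key, beta, memory):
    """Softmax over beta-scaled similarities: one weight per memory row."""
    scores = beta * cosine_similarity(key, memory)
    scores = scores - scores.max()  # numerical stability; does not change the softmax
    weights = np.exp(scores)
    return weights / weights.sum()

# Toy memory with N=3 rows of width M=2.
memory = np.array([[1.0, 0.0],
                   [0.0, 1.0],
                   [1.0, 1.0]])
key = np.array([1.0, 0.0])

w_c = content_addressing(key, beta=5.0, memory=memory)
print(w_c)  # the largest weight lands on row 0, the row most similar to the key
```

Try raising `beta`: the weights move closer to a one-hot vector on the best-matching row, which is what "amplification" means here.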
  • Interpolation (the start of location-based addressing)

So, this step mainly aims at including past information alongside the new, content-based weights wᶜₜ we have just generated using memory + input. How much should we retain from the past? The interpolation gate ‘g’ helps us with that. If you are getting confused, remember the forget gate from LSTM. This is something similar. The mathematical representation:

wᵍₜ = gₜ × wᶜₜ + (1 − gₜ) × wₜ₋₁

Here, gₜ (a value between 0 and 1) acts as the interpolation gate, wᶜₜ is the content-based Attention vector for the current time step, and wₜ₋₁ is the final Attention vector from the previous time step. So gₜ decides how much of the past Attention we keep in the current time step: with gₜ = 1 we use only the new content-based weights, with gₜ = 0 we simply reuse the previous weights.
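The interpolation is a one-liner, so a quick sketch is mostly useful for seeing the two extremes of the gate. The function name and the toy vectors below are assumptions of mine.

```python
import numpy as np

def interpolate(w_content, w_prev, g):
    """Blend the new content-based weights with the previous time step's weights."""
    return g * w_content + (1.0 - g) * w_prev

w_content = np.array([0.7, 0.2, 0.1])  # weights from content-based addressing
w_prev = np.array([0.0, 0.0, 1.0])     # attention vector from the previous time step

w_new_only = interpolate(w_content, w_prev, g=1.0)  # ignores the past
w_old_only = interpolate(w_content, w_prev, g=0.0)  # ignores the content lookup
w_half = interpolate(w_content, w_prev, g=0.5)      # equal mix of both
print(w_half)
```

Since both inputs sum to 1 and the gate is between 0 and 1, the blended vector still sums to 1.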

  • Shift weighting

This is a bit complicated. This step tries to shift focus amongst the different memory rows in the explicit memory. It does so by implementing a rotational shift of weights. For example, if the current weighting focuses entirely on a single location, a rotation of 1 would shift the focus to the next location. For this, the controller has a shift weighting function S which helps in performing the rotational shift of weights. The formula for the operation is

w̃ₜ(i) = Σⱼ₌₀ᴺ⁻¹ wᵍₜ(j) × Sₜ(i − j)

Let’s try to understand this complicated equation:

t = time step

i = iᵗʰ weight index in the Attention Vector we are preparing

j = goes from 0 to N−1 in the summation (N is the number of memory rows)

wᵍₜ(j) = the weight for row j coming out of the previous (interpolation) step

Sₜ(i−j) = the weight shifting function’s value for a shift of (i−j). The index is taken modulo N, so it wraps around.

Hence, the updated w̃ₜ(i) is a weighted sum over the Attention Vector entries of the different memory rows. How are we giving weights? Using the weight shifting function, i.e. S.

So let’s consider an example. Assume N=5 and we wish to update the weight for i=3, then

w̃[3] = w[0]*S(3) + w[1]*S(2) + w[2]*S(1) + w[3]*S(0) + w[4]*S(-1)

Because the shift wraps around, S(-1) is the same as S(4) here. Now depending on the weight shifting function S, the Attention Vector at index=3 will be calculated.

If you have noticed, the above formula is nothing but a (circular) convolution operation (remember CNNs). This can cause leakage or dispersion of weightings over time if the shift weighting is not sharp. For example, if S(-1), S(0), S(1) = 0.05, 0.9, 0.05, the rotation will transform a weighting focused at a single point into one slightly blurred over three points. To avoid such ‘blurry’ results, the last operation sharpens the weights.
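Here is a minimal sketch of the rotational shift, written with explicit loops so it mirrors the summation term by term. One assumption to flag: I store S as a length-N array where index k holds the weight for a shift of k modulo N, so S(-1) lives at the last position. The function name is mine.

```python
import numpy as np

def circular_shift(w, s):
    """Circular convolution: out[i] = sum_j w[j] * s[(i - j) mod N]."""
    n = len(w)
    out = np.zeros(n)
    for i in range(n):
        for j in range(n):
            out[i] += w[j] * s[(i - j) % n]
    return out

n = 5
w = np.zeros(n)
w[3] = 1.0  # all the focus is on row 3

# A sharp shift: S(1) = 1, so the focus rotates to the next row.
s_sharp = np.zeros(n)
s_sharp[1] = 1.0
shifted = circular_shift(w, s_sharp)
print(shifted)  # [0. 0. 0. 0. 1.]

# The blurry example from the text: S(-1), S(0), S(1) = 0.05, 0.9, 0.05.
s_blurry = np.zeros(n)
s_blurry[-1], s_blurry[0], s_blurry[1] = 0.05, 0.9, 0.05
blurred = circular_shift(w, s_blurry)
print(blurred)  # [0.   0.   0.05 0.9  0.05]
```

The second result shows the leakage: a weighting that was entirely on row 3 is now spread over rows 2, 3 and 4.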

  • Sharpening

To avoid blurring due to the convolution, the last of the 4 operations we perform is sharpening, for which we need a scalar γₜ ≥ 1 (like βₜ and gₜ, it is produced by the controller). The final Attention vector =

wₜ(i) = w̃ₜ(i)^γₜ / Σⱼ w̃ₜ(j)^γₜ

Where w̃ₜ is the shifted weighting from the previous step and j goes from 0 to N−1 (N is the total number of rows in the memory matrix)
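A short sketch of sharpening, applied to the blurred weighting from the shift example (0.05, 0.9, 0.05). The function name and the choice of γ = 2 are my own, just for illustration.

```python
import numpy as np

def sharpen(w, gamma):
    """Raise each weight to the power gamma (>= 1) and renormalize to sum to 1."""
    powered = w ** gamma
    return powered / powered.sum()

blurred = np.array([0.0, 0.0, 0.05, 0.9, 0.05])

unchanged = sharpen(blurred, gamma=1.0)  # gamma = 1 leaves the weights as they are
sharpened = sharpen(blurred, gamma=2.0)  # larger gamma pushes mass to the biggest weight
print(sharpened)
```

With γ = 2 the weight on row 3 becomes 0.81 / 0.815, roughly 0.994, so most of the leakage is gone.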

So, this is how we generate the Attention Vector !!

Now let’s quickly walk through how the read & write operations are done by an NTM

Read

Read is nothing but a weighted sum of the memory rows we have in the memory matrix. The weights? The Attention vector we calculated above

readₜ = Σᵢ wₜ(i) × Memoryₜ(i)
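In code, the read is a single vector-matrix product. The sketch below also shows a "hard" read of one row for comparison; the toy memory and weights are assumed values.

```python
import numpy as np

def read(memory, w):
    """Weighted sum of memory rows: memory is N x M, w has length N."""
    return w @ memory

memory = np.array([[1.0, 0.0],
                   [0.0, 1.0],
                   [1.0, 1.0]])
w = np.array([0.1, 0.8, 0.1])  # attention mostly on row 1

soft_read = read(memory, w)
hard_read = memory[1]  # picking a single row by index (not differentiable w.r.t. the index)
print(soft_read)  # [0.2 0.9]
print(hard_read)  # [0. 1.]
```

The soft read is close to row 1 but blends in a little of the other rows, and because it is just multiplications and additions, gradients can flow back into `w`.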

Write

Writing involves two major steps: first erase from the memory matrix, then add to it

Erase

Erasedₜ(i) = Memoryₜ₋₁(i) × [1 − wₜ(i) × eₜ]

Here, eₜ is an additional vector (erase vector) with all values between 0 and 1, 1 is a vector of ones, and the multiplication with the memory row is element-wise. The elements of a memory location are reset to zero only if both the weighting at the location and the erase element are one; if either the weighting or the erase is zero, the memory is left unchanged.

Add

To add new information to the memory matrix, the controller uses another additional vector aₜ (add vector). It is applied to the erased memory from the previous step, the formula being

Memoryₜ(i) = Erasedₜ(i) + wₜ(i) × aₜ
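Both write steps can be sketched together with outer products, since every row i gets scaled by its own weight wₜ(i). The function name and the toy values are assumptions for illustration only.

```python
import numpy as np

def write(memory_prev, w, erase, add):
    """Erase then add. memory_prev is N x M, w has length N, erase/add have length M."""
    # Row i is multiplied element-wise by (1 - w[i] * erase).
    erased = memory_prev * (1.0 - np.outer(w, erase))
    # Row i then receives w[i] * add.
    return erased + np.outer(w, add)

memory_prev = np.ones((3, 2))
w = np.array([0.0, 1.0, 0.0])   # full focus on row 1
erase = np.array([1.0, 0.0])    # wipe the first element, keep the second
add = np.array([0.5, 0.5])

memory_new = write(memory_prev, w, erase, add)
print(memory_new)
```

Only row 1 changes: its first element is erased and replaced by 0.5, its second is kept and increased to 1.5, while rows 0 and 2 (weight 0) stay as they were.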

Both eₜ and aₜ are produced by the controller. How exactly? The paper doesn’t go into that (or I might have missed it). So, this is how reading, writing & memory matrix updates happen in a Neural Turing Machine.

With this, it’s a wrap. We will be discussing Memory Augmented NNs for One-Shot Learning next, based on NTMs.
