Recurrent Model of Visual Attention

September 21, 2015 by Nicholas Leonard

In this blog post, I want to discuss how we at Element-Research implemented the recurrent attention model (RAM) described in [1]. Not only were we able to reproduce the paper, but we also made a bunch of modular code available in the process. You can reproduce the RAM on the MNIST dataset using this training script. We will use snippets of that script throughout this post. You can then evaluate your trained models using the evaluation script.

The paper describes a RAM that can be applied to image classification datasets. The model is designed so that it only has a bandwidth-limited sensor of the input image. As an example, if the input image is of size 28x28 (height x width), the RAM may only be able to sense an area of size 8x8 at any given time-step. These small sensed areas are referred to as glimpses.

SpatialGlimpse

Actually, the paper's glimpse sensor is a little bit more complicated than what we described above, but nevertheless simple. Keeping in step with the modular spirit of Torch7's nn package, we created a SpatialGlimpse module.

module = nn.SpatialGlimpse(size, depth, scale)

Basically, if you feed it an image, say the classic 3x512x512 Lenna image:

lenna

And you run this code on it:

require 'dpnn'
require 'image'

img = image.lena()
loc = torch.Tensor{0,0} -- 0.0 is the center of the image

sg = nn.SpatialGlimpse(64, 3, 2)

output = sg:forward{img, loc}
print(output:size()) --   9 x 64 x 64


-- unroll the glimpse onto its different depths
outputs = torch.chunk(output, 3)
display = image.toDisplayTensor(outputs)
image.save("glimpse-output.png", display)

You will end up with the following (unrolled) glimpse :

glimpse

While the input is a 3x512x512 image (786432 scalars), the output is very small : 9x64x64 (36864 scalars), or about 5% the size of the original image. Since the glimpse here has depth=3 (i.e. it uses 3 patches), each successive patch is scale times the size of the previous one. So we end up with a high-resolution patch of a small area, and successively lower-resolution (i.e. downscaled) patches of larger areas. This carries fascinating parallels to how our own human attention mechanism may work. Quite often, we humans can vividly see the details of what we focus our attention on, while nevertheless maintaining a blurry sense of the outskirts.
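The arithmetic above can be checked with a small plain-Lua sketch (no Torch required). The helper name glimpseStats is made up for this illustration, and it assumes that every patch ends up as a size x size map, as the 9x64x64 output suggests.

```lua
-- Hypothetical helper: summarises what a glimpse of the given geometry covers.
-- Assumes each patch is rescaled to size x size before the patches are stacked.
function glimpseStats(channels, height, width, size, depth, scale)
   local stats = {patches = {}}
   local side = size
   for d = 1, depth do
      -- side length (in input pixels) of the area covered by patch d
      stats.patches[d] = side
      side = side * scale
   end
   stats.inputScalars = channels * height * width
   stats.outputScalars = channels * depth * size * size
   stats.ratio = stats.outputScalars / stats.inputScalars
   return stats
end

-- Same geometry as the Lenna example: nn.SpatialGlimpse(64, 3, 2) on 3x512x512
stats = glimpseStats(3, 512, 512, 64, 3, 2)
print(table.concat(stats.patches, ", "))   -- covered side length per depth
print(stats.inputScalars, stats.outputScalars, stats.ratio)
```

With these numbers the three patches cover areas of side 64, 128 and 256 pixels, and the ratio comes out a little under 5%.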

Although not necessary for reproducing this paper, the SpatialGlimpse can also be partially backpropagated through. While the gradient w.r.t. the img tensor can be obtained, the gradients w.r.t. the location (i.e. the x,y coordinates of the glimpse) will be zero. This is because the glimpse operation cannot be differentiated w.r.t. the location. Which brings us to the main difficulty of attention models : how can we teach the network to glimpse at the right locations?

REINFORCE algorithm

Some attention models use a fully differentiable attention mechanism, like the recent DRAW paper [2]. But the RAM model uses a non-differentiable attention mechanism. Specifically, it uses the REINFORCE algorithm [3]. This algorithm allows one to train stochastic units through reinforcement learning.

The REINFORCE algorithm is very powerful as it can be used to optimize stochastic units (conditioned on an input) such that they maximize an objective function (i.e. a reward function). Unlike backpropagation [4], this objective need not be differentiable.

The RAM model uses the REINFORCE algorithm to train the locator network :

-- actions (locator)
locator = nn.Sequential()
locator:add(nn.Linear(opt.hiddenSize, 2))
locator:add(nn.HardTanh()) -- bounds mean between -1 and 1
locator:add(nn.ReinforceNormal(2*opt.locatorStd)) -- sample from normal, uses REINFORCE learning rule
locator:add(nn.HardTanh()) -- bounds sample between -1 and 1
locator:add(nn.MulConstant(opt.unitPixels*2/ds:imageSize("h")))

The input is the previous recurrent hidden state h[t-1]. During training, the output is sampled from a normal distribution with fixed standard deviation. The mean is conditioned on h[t-1] through an affine transform (i.e. Linear module). During evaluation, instead of sampling from the distribution, the output is taken to be the input, i.e. the mean.
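To make the difference between training and evaluation concrete, here is a rough plain-Lua sketch of what happens to a single coordinate after the Linear layer. The function names (hardTanh, sampleNormal, locate) and all numeric values are assumptions made for this illustration; the real modules operate on tensors.

```lua
-- Clamp v to [-1, 1], like nn.HardTanh with its default bounds.
function hardTanh(v)
   return math.max(-1, math.min(1, v))
end

-- Draw one sample from N(mean, std^2) using the Box-Muller transform.
function sampleNormal(mean, std)
   local u1 = 1 - math.random()   -- in (0, 1], avoids log(0)
   local u2 = math.random()
   return mean + std * math.sqrt(-2 * math.log(u1)) * math.cos(2 * math.pi * u2)
end

-- Hypothetical stand-in for the locator's tail (everything after Linear).
-- preActivation : one output of the Linear layer
-- std           : fixed standard deviation
-- scale         : the MulConstant factor
-- training      : sample when true, otherwise pass the mean through
function locate(preActivation, std, scale, training)
   local mean = hardTanh(preActivation)
   local sample = training and sampleNormal(mean, std) or mean
   return hardTanh(sample) * scale
end

math.randomseed(1)
scale = 13 * 2 / 28                     -- assumed unitPixels = 13, image height = 28
print(locate(0.3, 0.2, scale, true))    -- training: differs from call to call
print(locate(0.3, 0.2, scale, false))   -- evaluation: always the scaled mean
```

The second HardTanh matters because a sample drawn near the border can land outside [-1, 1] even when the mean does not.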

The ReinforceNormal module implements the REINFORCE algorithm for the normal distribution. Unlike most Modules, ReinforceNormal ignores the gradOutput when backward is called. This is because the units it embodies are in fact stochastic. So then how does it generate gradInputs when the backward is called? It uses the REINFORCE algorithm, which requires that a reward function be defined. The reward used by the paper is very simple, and yet not differentiable (eq. 1) :

R = I(y=t)

where R is the raw reward, I(x) is 1 when x is true and 0 otherwise (see indicator function), y is the predicted class, t is the target class. Or in Lua :

R = (y==t) and 1 or 0

The REINFORCE algorithm requires that we differentiate the probability density or mass function (PDF/PMF) of the distribution w.r.t. the parameters. So given the following variables :

  • f : normal probability density function
  • x : the sampled values (i.e. ReinforceNormal.output)
  • u : mean (input to ReinforceNormal)
  • s : standard deviation (ReinforceNormal.stdev)

the derivative of log normal w.r.t. mean u is :

d ln(f(x,u))   (x - u)
------------ = -------
     d u         s^2
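As a sanity check on this derivative, the following plain-Lua sketch compares it against a central finite difference of the log-density. The function names and the values of x, u and s are arbitrary choices for the illustration.

```lua
-- Log-density of a normal distribution with mean u and standard deviation s.
function logNormal(x, u, s)
   return -0.5 * math.log(2 * math.pi * s * s) - (x - u)^2 / (2 * s * s)
end

-- Analytic derivative of the log-density w.r.t. the mean u.
function dLogNormal(x, u, s)
   return (x - u) / (s * s)
end

x, u, s = 0.7, 0.2, 0.5     -- arbitrary illustrative values
eps = 1e-6
-- central finite difference w.r.t. u
numeric = (logNormal(x, u + eps, s) - logNormal(x, u - eps, s)) / (2 * eps)
print(dLogNormal(x, u, s), numeric)   -- the two values should agree closely
```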

So where does d ln(f(x,u)) / d u fit in with the reward? Well, in order to obtain a gradient of the reward R w.r.t. the input u, we apply the following equation (eq. 2, i.e. the REINFORCE algorithm):

d R                 d ln(f(x,u)) 
--- = a * (R - b) * --------------
d u                      d u 

where

  • a (alpha) is just a scaling factor like the learning rate;
  • b is a baseline reward used to reduce the variance of the gradient;
  • f is the PMF/PDF;
  • u is the parameter w.r.t. which you want to get a gradient; and
  • x is the sampled value.
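For the normal distribution with a fixed standard deviation s, eq. 2 reduces to a one-liner. The sketch below is a plain-Lua illustration with a made-up function name and made-up numbers; it is not the code of the ReinforceNormal module.

```lua
-- Eq. 2 for a normal distribution with fixed standard deviation s:
-- gradient of the reward w.r.t. the mean u, given a sample x.
function reinforceNormalGrad(x, u, s, R, b, a)
   return a * (R - b) * (x - u) / (s * s)
end

-- A sample to the right of the mean, first rewarded, then not rewarded.
print(reinforceNormalGrad(0.5, 0.25, 0.5, 1, 0.5, 1))   -- R = 1 beats the baseline
print(reinforceNormalGrad(0.5, 0.25, 0.5, 0, 0.5, 1))   -- R = 0 falls short of it
```

When the reward beats the baseline, the gradient points from the mean towards the sample, so following it makes that sample more likely. When the reward falls short, the sign flips. An optimizer that minimizes would need this gradient negated.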

In the paper, they take b to be the expected reward E[R]. They approximate the expectation by making b a (conditionally independent) parameter of the model. For each example, the mean squared error between R and b is minimized by backpropagation. The beauty of doing this, instead of, say, making b a moving average of R, is that the baseline reward b learns at the same rate as the rest of the model.
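The idea of learning b by minimizing the squared error can be sketched in a few lines of plain Lua. The function name, the rewards and the learning rate below are all made up for the illustration; in the real model b is updated alongside the other parameters.

```lua
-- Learn a scalar baseline b by gradient descent on the squared error (R - b)^2.
function trainBaseline(rewards, learningRate, epochs)
   local b = 0
   for epoch = 1, epochs do
      for _, R in ipairs(rewards) do
         -- d/db (R - b)^2 = -2 * (R - b), so descending the gradient adds 2 * (R - b)
         b = b + learningRate * 2 * (R - b)
      end
   end
   return b
end

rewards = {1, 0, 1, 1}                  -- made-up rewards: 3 correct out of 4
b = trainBaseline(rewards, 0.01, 200)
print(b)                                -- settles near the mean reward of 0.75
```

With a constant learning rate b keeps wobbling slightly around the mean reward instead of converging exactly, which is good enough for a variance-reducing baseline.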

We decided to implement the reward = a * (R - b) part of eq. 2 within a VRClassReward criterion, which also implements the paper's variance-reduced classification reward function (eq. 1):

vcr = nn.VRClassReward(module [, scale, criterion])

The nn package was primarily built for backpropagation, so we had to find a not-too-hacky way of broadcasting the reward to the different Reinforce modules. We did this by having the criterion take the module as argument, and adding the Module:reinforce(reward) method. The latter allows the Reinforce modules (like ReinforceNormal) to hold on to the criterion's broadcast reward for later. Later being when backward is called and the Reinforce module(s) compute(s) the gradInput (i.e. d R / d u) using the REINFORCE algorithm (eq. 2). And then nn is happy, because with the gradInput, it can continue the backpropagation from the Reinforce module to its predecessors.
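The flow described above (the criterion hands the reward to the module, the module stores it, and backward uses it) can be mimicked in plain Lua. Everything in this sketch is a hypothetical stand-in: StochasticUnit and broadcastReward are not dpnn classes, and forward takes the sample as an argument only to keep the example deterministic.

```lua
-- Hypothetical stand-ins that mimic the reward flow; not the dpnn classes.
StochasticUnit = {}
StochasticUnit.__index = StochasticUnit

function StochasticUnit.new(stdev)
   return setmetatable({stdev = stdev}, StochasticUnit)
end

-- Store the broadcast reward until backward is called.
function StochasticUnit:reinforce(reward)
   self.reward = reward
end

-- Remember the mean and the sample from the forward pass.
function StochasticUnit:forward(mean, sample)
   self.mean, self.sample = mean, sample
   return sample
end

-- gradOutput is ignored: the gradient comes from the stored reward (eq. 2).
function StochasticUnit:backward(gradOutput)
   return self.reward * (self.sample - self.mean) / (self.stdev * self.stdev)
end

-- The criterion computes reward = a * (R - b) and hands it to the module.
function broadcastReward(module, y, t, b, a)
   local R = (y == t) and 1 or 0
   module:reinforce(a * (R - b))
end

unit = StochasticUnit.new(0.5)
unit:forward(0.25, 0.5)               -- mean 0.25, sampled value 0.5
broadcastReward(unit, 3, 3, 0.5, 1)   -- correct prediction, baseline 0.5
print(unit:backward(123))             -- same result whatever gradOutput is
```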

So to summarize, if you can differentiate the PMF/PDF w.r.t. its parameters, then you can use the REINFORCE algorithm on it. We have already implemented modules for the categorical and Bernoulli distributions.

Recurrent Attention Model

rva

Okay, so we discussed the glimpse module and the REINFORCE algorithm, let's talk about the recurrent attention model. We can divide the model into its respective components :

The location sensor. Its inputs are the x, y coordinates of the current glimpse location, so the network knows where it's looking at each time-step :

locationSensor = nn.Sequential()
locationSensor:add(nn.SelectTable(2))
locationSensor:add(nn.Linear(2, opt.locatorHiddenSize))
locationSensor:add(nn[opt.transfer]())

The glimpse sensor is what it's looking at :

glimpseSensor = nn.Sequential()
glimpseSensor:add(nn.DontCast(nn.SpatialGlimpse(opt.glimpsePatchSize, opt.glimpseDepth, opt.glimpseScale):float(),true))
glimpseSensor:add(nn.Collapse(3))
glimpseSensor:add(nn.Linear(ds:imageSize('c')*(opt.glimpsePatchSize^2)*opt.glimpseDepth, opt.glimpseHiddenSize))
glimpseSensor:add(nn[opt.transfer]())

The glimpse network is where the glimpse and location sensor mix through a hidden layer to form the input layer of the Recurrent Neural Network (RNN) :

glimpse = nn.Sequential()
glimpse:add(nn.ConcatTable():add(locationSensor):add(glimpseSensor))
glimpse:add(nn.JoinTable(1,1))
glimpse:add(nn.Linear(opt.glimpseHiddenSize+opt.locatorHiddenSize, opt.imageHiddenSize))
glimpse:add(nn[opt.transfer]())
glimpse:add(nn.Linear(opt.imageHiddenSize, opt.hiddenSize))

The RNN is where the glimpse network and the recurrent layer come together. The output of the rnn module is the hidden state h[t] where t indexes the time-step. We used the rnn package to build it :

-- rnn recurrent layer
recurrent = nn.Linear(opt.hiddenSize, opt.hiddenSize)

-- recurrent neural network
rnn = nn.Recurrent(opt.hiddenSize, glimpse, recurrent, nn[opt.transfer](), 99999)
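As a mental model of what this layer computes, here is a plain-Lua sketch of a simple recurrence on small vectors. It assumes that the feedback from h[t-1] is added to the glimpse network's output before the transfer function is applied; the function names, weights and inputs are made up for the illustration.

```lua
-- One step of a simple recurrence on vectors stored as Lua tables:
-- h[t] = transfer(inputOut[t] + W * h[t-1] + bias)
function recurrentStep(inputOut, hPrev, W, bias, transfer)
   local h = {}
   for i = 1, #inputOut do
      local sum = inputOut[i] + bias[i]
      for j = 1, #hPrev do
         sum = sum + W[i][j] * hPrev[j]
      end
      h[i] = transfer(sum)
   end
   return h
end

function relu(v) return math.max(0, v) end

W = {{0.5, 0}, {0, -1}}          -- made-up feedback weights
bias = {0, 0}
h = {0, 0}                       -- initial hidden state h[0]
for t, inputOut in ipairs({{1, 1}, {1, 1}}) do   -- made-up glimpse-network outputs
   h = recurrentStep(inputOut, h, W, bias, relu)
   print(t, h[1], h[2])
end
```

Even though both time-steps receive the same input here, the hidden state differs between them because it also depends on h[t-1].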

We have already seen the locator network above, but here it is again :

-- actions (locator)
locator = nn.Sequential()
locator:add(nn.Linear(opt.hiddenSize, 2))
locator:add(nn.HardTanh()) -- bounds mean between -1 and 1
locator:add(nn.ReinforceNormal(2*opt.locatorStd)) -- sample from normal, uses REINFORCE learning rule
locator:add(nn.HardTanh()) -- bounds sample between -1 and 1
locator:add(nn.MulConstant(opt.unitPixels*2/ds:imageSize("h")))

The locator's task is to sample the next location l[t] (or action) given the previous hidden state h[t-1], i.e. the previous output of the rnn. For the first step, we use h[0] = 0 (a tensor of zeros) for the previous hidden state. We should also draw your attention to the opt.unitPixels variable, as it is very important and not specified in the paper. This variable determines how far (in pixels) from the center of the image the center of each glimpse can reach. So a value of 13 (the default) means that the center of the glimpse can be anywhere between the 2nd and 27th pixel (for a 1x28x28 MNIST example). So glimpses of the corner will have fewer zero-padding values than if opt.unitPixels = 14.
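The effect of opt.unitPixels can be worked out with a few lines of plain Lua. This sketch assumes that a location l in [-1, 1] maps linearly to the pixel offset (l + 1) / 2 * imageSize measured from the top or left edge; the function name centerRange is made up for the illustration.

```lua
-- Range of glimpse centers (in pixels from the top/left edge) that can be
-- reached for a given unitPixels.
-- Assumes a location l in [-1, 1] maps linearly to (l + 1) / 2 * imageSize.
function centerRange(unitPixels, imageSize)
   local scale = unitPixels * 2 / imageSize      -- the MulConstant factor
   local lo = (-scale + 1) / 2 * imageSize
   local hi = ( scale + 1) / 2 * imageSize
   return lo, hi
end

print(centerRange(13, 28))   -- about 1 and 27
print(centerRange(14, 28))   -- 0 and 28: the center can sit on the image border
```

Under that assumption the reachable range is simply imageSize / 2 plus or minus unitPixels.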

We need to encapsulate the rnn and locator into a module that would capture the essence of the RAM model. So we decided to implement a module that took the full image as input, and output the sequence of hidden states h. As its name implies, this module is a general purpose RecurrentAttention module :

attention = nn.RecurrentAttention(rnn, locator, opt.rho, {opt.hiddenSize})

It's so general-purpose that you can use it with an LSTM module, and with different glimpse or locator modules, as long as the glimpse module preserves the same {input, action} -> output interface and the locator uses the REINFORCE algorithm, of course.
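The control flow that such a module has to implement can be sketched in plain Lua. This is a hypothetical illustration of the loop, not the RecurrentAttention source: the function names and the toy scalar "networks" are made up so that the loop can run without Torch.

```lua
-- Hypothetical sketch of the control flow: alternate between choosing a
-- location from the previous hidden state and updating the hidden state.
function recurrentAttention(image, rnnStep, locatorStep, nSteps, h0)
   local hiddens, locations = {}, {}
   local h = h0
   for t = 1, nSteps do
      local l = locatorStep(h)          -- l[t] from h[t-1]
      h = rnnStep(image, l, h)          -- h[t] from {image, l[t]} and h[t-1]
      locations[t], hiddens[t] = l, h
   end
   return hiddens, locations
end

-- Toy scalar stand-ins for the locator and the rnn.
function toyLocator(h) return math.max(-1, math.min(1, h)) end
function toyRnn(image, l, h) return 0.5 * h + image * (l + 1) end

hiddens, locations = recurrentAttention(1, toyRnn, toyLocator, 3, 0)
print(#hiddens, hiddens[#hiddens])   -- a classifier would read the last hidden state
```

Any pair of functions with the same signatures could be swapped in, which is the sense in which the real module is general-purpose.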

Next we build an agent by stacking a classifier on top of the RecurrentAttention module. The classifier's input is the last hidden state h[T], where T is the total number of time-steps (i.e. glimpses to take):

-- model is a reinforcement learning agent
agent = nn.Sequential()
agent:add(nn.Convert(ds:ioShapes(), 'bchw'))
agent:add(attention)

-- classifier :
agent:add(nn.SelectTable(-1))
agent:add(nn.Linear(opt.hiddenSize, #ds:classes()))
agent:add(nn.LogSoftMax())
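For readers unfamiliar with the last layer, this plain-Lua sketch shows what a log-softmax computes and how the predicted class y falls out of it. The function names and the class scores are made up for the illustration.

```lua
-- Numerically stable log-softmax over a table of scores.
function logSoftMax(scores)
   local maxScore = -math.huge
   for _, v in ipairs(scores) do maxScore = math.max(maxScore, v) end
   local sum = 0
   for _, v in ipairs(scores) do sum = sum + math.exp(v - maxScore) end
   local logSum = maxScore + math.log(sum)
   local out = {}
   for i, v in ipairs(scores) do out[i] = v - logSum end
   return out
end

-- Index of the largest entry, i.e. the predicted class y.
function argmax(values)
   local best = 1
   for i = 2, #values do
      if values[i] > values[best] then best = i end
   end
   return best
end

logProbs = logSoftMax({2.0, 0.5, -1.0})   -- made-up class scores
y = argmax(logProbs)
print(y)
```

The predicted class y obtained this way is what gets compared with the target t in the reward R = I(y=t).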

You might recall that the REINFORCE algorithm requires a baseline reward b? This is where that happens (yes, it requires a little bit of nn kung-fu):

-- add the baseline reward predictor
seq = nn.Sequential()
seq:add(nn.Constant(1,1))
seq:add(nn.Add(1))
concat = nn.ConcatTable():add(nn.Identity()):add(seq)
concat2 = nn.ConcatTable():add(nn.Identity()):add(concat)

-- output will be : {classpred, {classpred, basereward}}
agent:add(concat2)
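The nested output can be easier to follow with plain Lua tables. The sketch below mirrors the structure only; the function names are made up, and the bias value of 0 is an assumption standing in for the learnable parameter of nn.Add.

```lua
-- Hypothetical plain-Lua mirror of the table structure built above.
baselineBias = 0                       -- stands in for nn.Add's learnable bias

-- Ignores its input, like nn.Constant(1,1), then adds the bias, like nn.Add(1).
function baselinePredictor(classpred)
   return 1 + baselineBias
end

function agentOutput(classpred)
   local inner = {classpred, baselinePredictor(classpred)}   -- concat
   return {classpred, inner}                                 -- concat2
end

out = agentOutput({-0.2, -1.7})        -- made-up log-probabilities
print(out[1] == out[2][1], out[2][2])
```

The baseline does not depend on the input at all: it is a constant plus a bias, so the bias is the only thing that gets learned. The outer table presumably lets one criterion consume the class prediction on its own while the reward criterion receives both the prediction and the baseline.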

At this point, the model is ready for training. The agent instance is effectively the RAM model detailed in the paper.

Training

Training is pretty straightforward after that. The script can be launched with the default hyper-parameters. It should take 1-2 runs to find a validation minimum that reproduces the MNIST test results of the paper. We are still working on reproducing the Translated and Cluttered MNIST results.

Results

In the paper, they get 1.07% error on MNIST with 7 glimpses. We got 0.85% error after 853 epochs of training (early-stopped on the validation set of course). Here are some glimpse sequences. Note how the first frames start at the same location :

locations

And this is what the model sees for those sequences:

glimpses

So basically, REINFORCE did a good job of teaching the module to focus its attention on different regions given different inputs. A failure mode (bad) would have been to see the model do this :

fail

In that particular case, the opt.unitPixels was set to 6 instead of 12 (so the attention was limited to a small region in the center of the image). You don't want this to happen because it means the attention is conditionally independent of the input (i.e. dumb).

Here are some results for the Translated MNIST dataset :

translatedmnist

For this dataset, the images are of size 1x60x60 where each image contains a randomly placed 1x28x28 MNIST digit. The 3x12x12 glimpse uses a depth of 3 scales where each successive patch is twice the height and width of the previous one. After 683 epochs of training on the Translated MNIST dataset, using 7 glimpses, we obtain 0.92% error. The paper reaches 1.22% and 1.2% error for 6 and 8 glimpses, respectively. The exact command used to obtain those results:

th examples/recurrent-visual-attention.lua --cuda --dataset TranslatedMnist --unitPixels 26 --learningRate 0.001 --glimpseDepth 3 --maxTries 200 --stochastic --glimpsePatchSize 12

Note : you can evaluate your models with the evaluation script. It will generate a sample of glimpse sequences and print the confusion matrix results for the test set.

Conclusion

This blog post discussed a concrete implementation of a Recurrent Model of Visual Attention using the REINFORCE algorithm. The REINFORCE algorithm is very powerful as it can learn from non-differentiable criteria through reinforcement learning. However, like many reinforcement learning algorithms, it does take a long time to converge. Nevertheless, as demonstrated here and in the original paper, the ability to train a model to learn where to focus its attention can provide significant performance improvements.

References

  1. Volodymyr Mnih, Nicolas Heess, Alex Graves, Koray Kavukcuoglu, Recurrent Models of Visual Attention, NIPS 2014
  2. Gregor, Karol, et al., DRAW: A recurrent neural network for image generation, Arxiv 2015
  3. Williams, Ronald J., Simple statistical gradient-following algorithms for connectionist reinforcement learning, Machine learning 8.3-4 (1992): 229-256.
  4. Rumelhart, David E., Geoffrey E. Hinton, and Ronald J. Williams. Learning internal representations by error propagation, No. ICS-8506