In this guide, I will show you how to code a Convolutional Long Short-Term Memory (ConvLSTM) using an autoencoder (seq2seq) architecture for frame prediction using the MovingMNIST dataset (but custom datasets can also easily be integrated).
This method was originally proposed for precipitation nowcasting (Shi et al., NIPS 2015) and has since been extended extensively with methods such as PredRNN, PredRNN++, Eidetic 3D LSTM, and so on…
We also use the pytorch-lightning framework, which removes a lot of the boilerplate code and makes it easy to integrate 16-bit (mixed-precision) and multi-GPU training.
Before starting, we will briefly outline the libraries we are using:
python=3.6.8
torch=1.1.0
torchvision=0.3.0
pytorch-lightning=0.7.1
matplotlib=3.1.3
tensorboard=1.15.0a20190708
Download the dataloader script from the following repo: tychovdo/MovingMNIST.
This dataset was originally introduced in Srivastava et al. (2015), and it contains 10,000 sequences, each 20 frames long with a frame size of 64 x 64, showing two digits moving along various (and overlapping) trajectories.
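Once MovingMNIST.py from that repo is on your path, the dataset can be wrapped in a standard DataLoader. Below is a minimal sketch, assuming the torchvision-style interface the repo mimics (check the repo for the exact constructor signature):

```python
import torch
from MovingMNIST import MovingMNIST  # the dataset class from tychovdo/MovingMNIST

# assumption: torchvision-style constructor with automatic download
train_set = MovingMNIST(root='.data/mnist', train=True, download=True)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=12, shuffle=True)

# each item is an (input, target) pair: typically the first 10 frames of a sequence
# as input and the last 10 frames as the prediction target
x, y = next(iter(train_loader))
print(x.shape, y.shape)  # e.g. torch.Size([12, 10, 64, 64]) each
```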
Something to note beforehand is the inherent randomness of the digit trajectories. We expect this to be a major hurdle for the model we are about to describe, and newer approaches such as variational autoencoders may well be better suited to this type of task.
The specific model type we will be using is called a seq2seq model, which is typically used for NLP or time-series tasks (it was actually implemented in the Google Translate engine in 2016).
The original papers on seq2seq are Sutskever et al., 2014 and Cho et al., 2014.
In its simplest configuration, the seq2seq model takes a sequence of items as input (such as words, word embeddings, letters, etc.) and outputs another sequence of items. For machine translation, the input could be a sequence of Spanish words and the output would be the English translation.
We can separate the seq2seq model into three parts, which are
a) Encoder (encodes the input list)
b) Encoder embedding vector (the final embedding of the entire input sequence)
c) Decoder (decodes the embedding vector into the output sequence)
For our machine translation example, this would mean that the encoder ingests the Spanish sentence one word at a time, the embedding vector summarizes the entire input sentence, and the decoder unrolls that vector into the English translation, word by word.
Hopefully parts a) and c) are somewhat clear to you. Arguably the trickiest part of the seq2seq model, in terms of intuition, is the encoder embedding vector. How exactly do you define this vector?
Before you move any further, I highly recommend the following excellent blog post on RNN/LSTM. Understanding LSTMs intimately is an essential prerequisite for most seq2seq models!
Here are the equations for the regular LSTM cell:
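(The formulation below, including the peephole connections, follows the FC-LSTM equations used in the ConvLSTM paper by Shi et al., 2015.)

$$
\begin{aligned}
i_t &= \sigma(W_{xi} x_t + W_{hi} h_{t-1} + W_{ci} \circ c_{t-1} + b_i) \\
f_t &= \sigma(W_{xf} x_t + W_{hf} h_{t-1} + W_{cf} \circ c_{t-1} + b_f) \\
c_t &= f_t \circ c_{t-1} + i_t \circ \tanh(W_{xc} x_t + W_{hc} h_{t-1} + b_c) \\
o_t &= \sigma(W_{xo} x_t + W_{ho} h_{t-1} + W_{co} \circ c_t + b_o) \\
h_t &= o_t \circ \tanh(c_t)
\end{aligned}
$$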
where ∘ denotes the Hadamard product.
So let's assume you fully understand what an LSTM cell is and how cell states and hidden states work. Typically, the encoder and decoder in seq2seq models consist of LSTM cells, as illustrated in the following figure:
Several extensions to the vanilla seq2seq model exist; the most notable being the Attention module.
Having discussed the seq2seq model, let's turn our attention to the task of frame prediction!
Frame prediction is inherently different from the original tasks of seq2seq, such as machine translation. This is due to the fact that the RNN modules (LSTMs) in the encoder and decoder use fully-connected layers to encode and decode word embeddings, which are represented as vectors.
Once we are dealing with frames, each input is a 2D tensor, and to encode and decode these in a sequential fashion we need an extension of the original LSTM seq2seq model.
This is where Convolutional LSTM (ConvLSTM) comes in. Presented at NIPS in 2015, ConvLSTM modifies the inner workings of the LSTM mechanism to use the convolution operation instead of simple matrix multiplication. Let's write our new equations for the ConvLSTM cells:
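Following Shi et al. (2015):

$$
\begin{aligned}
i_t &= \sigma(W_{xi} \ast X_t + W_{hi} \ast H_{t-1} + W_{ci} \circ C_{t-1} + b_i) \\
f_t &= \sigma(W_{xf} \ast X_t + W_{hf} \ast H_{t-1} + W_{cf} \circ C_{t-1} + b_f) \\
C_t &= f_t \circ C_{t-1} + i_t \circ \tanh(W_{xc} \ast X_t + W_{hc} \ast H_{t-1} + b_c) \\
o_t &= \sigma(W_{xo} \ast X_t + W_{ho} \ast H_{t-1} + W_{co} \circ C_t + b_o) \\
H_t &= o_t \circ \tanh(C_t)
\end{aligned}
$$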
where ∗ denotes the convolution operation and ∘ denotes the Hadamard product, as before.
Can you spot the subtle difference between these equations and regular LSTM? We simply replace the multiplications in the four gates between
a) weight matrices and input (Wₓ xₜ with Wₓ ∗ Xₜ) and
b) weight matrices and previous hidden state (Wₕ hₜ₋₁ with Wₕ ∗ Hₜ₋₁).
Otherwise, everything remains the same.
If you prefer not to dive into the above equations, the primary thing to note is that we now use convolutional kernels to turn our input images into feature maps, rather than flattening them into vectors and passing them through fully-connected layers.
One of the most difficult things when designing frame prediction models (with ConvLSTM) is defining how to produce the frame predictions. We list two methods here (but others do also exist):
1. Predict one frame at a time, feeding each predicted frame back into the network as input for the next step (autoregressive prediction).
2. Predict all future frames in a single pass, with the architecture sized to the number of steps to predict.
In this tutorial, we will focus on method 1, especially since it can produce any number of predictions into the future without having to change the architecture. Furthermore, if we want to predict many steps into the future, option 2 becomes increasingly computationally expensive.
For our ConvLSTM cell, we use the PyTorch implementation from ndrplz. It looks as follows:
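Below is a condensed version of that cell (the full repo also includes a multi-layer ConvLSTM wrapper; note that, unlike the equations above, this implementation omits the peephole terms W_c ∘ c):

```python
import torch
import torch.nn as nn

class ConvLSTMCell(nn.Module):
    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        super(ConvLSTMCell, self).__init__()
        self.hidden_dim = hidden_dim
        padding = kernel_size[0] // 2, kernel_size[1] // 2
        # a single convolution computes all four gates at once (4 * hidden_dim output channels)
        self.conv = nn.Conv2d(in_channels=input_dim + hidden_dim,
                              out_channels=4 * hidden_dim,
                              kernel_size=kernel_size,
                              padding=padding,
                              bias=bias)

    def forward(self, input_tensor, cur_state):
        h_cur, c_cur = cur_state
        # concatenate input and previous hidden state along the channel axis
        combined = torch.cat([input_tensor, h_cur], dim=1)
        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)          # input gate
        f = torch.sigmoid(cc_f)          # forget gate
        o = torch.sigmoid(cc_o)          # output gate
        g = torch.tanh(cc_g)             # candidate cell state
        c_next = f * c_cur + i * g       # new cell state
        h_next = o * torch.tanh(c_next)  # new hidden state
        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        height, width = image_size
        device = self.conv.weight.device
        return (torch.zeros(batch_size, self.hidden_dim, height, width, device=device),
                torch.zeros(batch_size, self.hidden_dim, height, width, device=device))
```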
Hopefully, you can see how the equations defined earlier are written in the above code for the forward pass.
The specific architecture we use looks as follows:
We use two ConvLSTM cells for both the encoder and the decoder (encoder_1_convlstm, encoder_2_convlstm, decoder_1_convlstm, decoder_2_convlstm).
Our final ConvLSTM cell (decoder_2_convlstm) outputs nf feature maps for each predicted frame, i.e. a tensor of shape (batch, time, nf, height, width) = (12, 10, 64, 64, 64).
As we are essentially doing regression (predicting pixel values), we need to transform these feature maps into actual frame predictions, much like a classification network transforms its final feature maps into class scores.
To achieve this we add a 3D CNN layer, which applies a 3D convolution over the stacked sequence of feature maps and reduces the nf channels to a single output channel per frame, i.e. the predicted pixel values.
Finally, since the input pixel values are scaled to [0, 1], we apply a sigmoid to the 3D CNN activations so that the predictions also lie in [0, 1].
And that is basically it!
Now we define the Python implementation of the seq2seq model:
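Below is a sketch of what this encoder-decoder module can look like. It reuses the ConvLSTMCell from above; the module and attribute names follow the architecture described earlier, while the constructor argument in_chan, the 3x3 kernels and the Conv3d kernel size are assumptions rather than a verbatim copy of the accompanying repo:

```python
import torch
import torch.nn as nn

class EncoderDecoderConvLSTM(nn.Module):
    def __init__(self, nf, in_chan):
        super(EncoderDecoderConvLSTM, self).__init__()
        # two ConvLSTM cells for the encoder and two for the decoder
        self.encoder_1_convlstm = ConvLSTMCell(in_chan, nf, kernel_size=(3, 3), bias=True)
        self.encoder_2_convlstm = ConvLSTMCell(nf, nf, kernel_size=(3, 3), bias=True)
        self.decoder_1_convlstm = ConvLSTMCell(nf, nf, kernel_size=(3, 3), bias=True)
        self.decoder_2_convlstm = ConvLSTMCell(nf, nf, kernel_size=(3, 3), bias=True)
        # 3D CNN: collapses the nf feature maps of every predicted frame into one output channel
        self.decoder_CNN = nn.Conv3d(in_channels=nf, out_channels=1,
                                     kernel_size=(1, 3, 3), padding=(0, 1, 1))

    def forward(self, x, future_seq=10):
        # x has shape (batch, time, channels, height, width)
        b, seq_len, _, h, w = x.size()
        h_t, c_t = self.encoder_1_convlstm.init_hidden(b, (h, w))
        h_t2, c_t2 = self.encoder_2_convlstm.init_hidden(b, (h, w))
        h_t3, c_t3 = self.decoder_1_convlstm.init_hidden(b, (h, w))
        h_t4, c_t4 = self.decoder_2_convlstm.init_hidden(b, (h, w))

        # encoder: consume the input frames one at a time
        for t in range(seq_len):
            h_t, c_t = self.encoder_1_convlstm(x[:, t], (h_t, c_t))
            h_t2, c_t2 = self.encoder_2_convlstm(h_t, (h_t2, c_t2))

        # the final hidden state acts as the encoder embedding of the whole input sequence
        encoder_vector = h_t2

        # decoder: autoregressively roll out future_seq predictions
        outputs = []
        for t in range(future_seq):
            h_t3, c_t3 = self.decoder_1_convlstm(encoder_vector, (h_t3, c_t3))
            h_t4, c_t4 = self.decoder_2_convlstm(h_t3, (h_t4, c_t4))
            encoder_vector = h_t4        # feed the prediction features back in
            outputs.append(h_t4)

        outputs = torch.stack(outputs, dim=1)     # (b, time, nf, h, w)
        outputs = outputs.permute(0, 2, 1, 3, 4)  # (b, nf, time, h, w) for Conv3d
        outputs = self.decoder_CNN(outputs)       # (b, 1, time, h, w)
        return torch.sigmoid(outputs)             # pixel values in [0, 1]
```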
Maybe you are already aware of the excellent library pytorch-lightning, which takes most of the boilerplate engineering out of machine learning with PyTorch, such as the usual optimizer.zero_grad() and optimizer.step() calls.
It also standardizes the training loop and makes multi-GPU and mixed-precision training (on Volta-generation GPUs) easy to enable.
There is so much functionality available in pytorch-lightning, and I will try to demonstrate the workflow I have created, which I think works fairly well.
Most of the functionality of the MovingMNISTLightning class is fairly self-explanatory. Here is the overall workflow:
1. We instantiate the model together with the relevant hyperparameters.
2. For each batch, the training step feeds the input frames through the seq2seq model to obtain a prediction and computes the loss against the ground-truth future frames.
3. The loss and learning rate are logged to tensorboard, and every 250 global steps we also log a visualization of the inputs, ground truth and predictions.
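A minimal sketch of what such a LightningModule might look like is given below. The loss choice (per-pixel MSE), the tensor shapes and the logging details are assumptions based on the description above, and the logging API shown matches older lightning releases (newer versions use self.log instead of the returned 'log' dictionary):

```python
import torch
import torch.nn as nn
import torchvision
import pytorch_lightning as pl

class MovingMNISTLightning(pl.LightningModule):
    def __init__(self, model, lr=1e-4):
        super(MovingMNISTLightning, self).__init__()
        self.model = model               # e.g. the EncoderDecoderConvLSTM defined above
        self.criterion = nn.MSELoss()    # assumption: plain per-pixel regression loss
        self.lr = lr

    def forward(self, x):
        return self.model(x, future_seq=10)

    def training_step(self, batch, batch_idx):
        # assumption: x are the 10 input frames of shape (b, t, 1, 64, 64) and
        # y the 10 target frames of shape (b, t, 64, 64), both scaled to [0, 1]
        x, y = batch
        y_hat = self(x).squeeze(1)       # drop the singleton channel dim -> (b, t, 64, 64)
        loss = self.criterion(y_hat, y)

        # every 250 global steps, log a visualization of the predicted frames
        if self.global_step % 250 == 0:
            grid = torchvision.utils.make_grid(y_hat[0].unsqueeze(1))
            self.logger.experiment.add_image('predictions', grid, self.global_step)

        # older lightning versions (0.7.x) log via the returned dict
        return {'loss': loss, 'log': {'train_loss': loss}}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
```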
When we actually run our main.py script we can define several relevant parameters. For example, if we want to run with 2 GPUs, mixed-precision and batch_size = 16 we simply type:
python main.py --n_gpus=2 --use_amp=True --batch_size=16
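Under the hood, main.py can forward these flags to the pytorch-lightning Trainer roughly as follows (a sketch: the flag names match the command above, max_epochs is arbitrary, and Trainer arguments differ between lightning versions, with newer releases using precision=16 instead of use_amp):

```python
import argparse
import torch
import pytorch_lightning as pl
from MovingMNIST import MovingMNIST

parser = argparse.ArgumentParser()
parser.add_argument('--n_gpus', type=int, default=1)
parser.add_argument('--use_amp', type=lambda s: s.lower() == 'true', default=False)
parser.add_argument('--batch_size', type=int, default=12)
opt = parser.parse_args()

train_loader = torch.utils.data.DataLoader(
    MovingMNIST(root='.data/mnist', train=True, download=True),
    batch_size=opt.batch_size, shuffle=True)

model = MovingMNISTLightning(EncoderDecoderConvLSTM(nf=64, in_chan=1))
trainer = pl.Trainer(gpus=opt.n_gpus,
                     use_amp=opt.use_amp,  # mixed-precision; newer versions: precision=16
                     max_epochs=100)       # arbitrary choice for this sketch
trainer.fit(model, train_loader)
```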
Feel free to experiment with various configurations!
When we run the main.py script, we automatically spin up a tensorboard session using multiprocessing, where you can track the performance of the model iteratively and also see the visualization of our predictions every 250 global steps.
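A simple way to achieve this is to launch tensorboard in a separate process before training starts (a sketch; the lightning_logs directory is the pytorch-lightning default, and tensorboard is assumed to be on your PATH):

```python
import os
from multiprocessing import Process

def run_tensorboard():
    # blocks inside its own process; serves the default pytorch-lightning log directory
    os.system('tensorboard --logdir=lightning_logs --port=6006')

if __name__ == '__main__':
    Process(target=run_tensorboard).start()
    # ... then build the model, Trainer and call trainer.fit(...) as above ...
```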
Thanks for reading this article! I hope you enjoyed it!
Please reach out either here or on Twitter if you have any questions or comments regarding the above article. You can also find more tutorials on my webpage https://holmdk.github.io/.